Skip to content

Commit 4306c4f

Browse files
committed
address comments
1 parent 414e116 commit 4306c4f

File tree

2 files changed

+32
-46
lines changed

2 files changed

+32
-46
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -937,8 +937,7 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
937937
*/
938938
case class OptimizeCodegen(conf: CatalystConf) extends Rule[LogicalPlan] {
939939
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
940-
case e: CaseWhen if canCodeGen(e) =>
941-
e.toCodegen()
940+
case e: CaseWhen if canCodeGen(e) => e.toCodegen()
942941
}
943942

944943
private def canCodeGen(e: CaseWhen): Boolean = {

sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala

Lines changed: 31 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -17,36 +17,14 @@
1717

1818
package org.apache.spark.sql.internal
1919

20-
import org.scalatest.BeforeAndAfterAll
21-
2220
import org.apache.spark.sql.{QueryTest, Row, SparkSession, SQLContext}
2321
import org.apache.spark.sql.execution.WholeStageCodegenExec
2422
import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext}
2523

26-
class SQLConfSuite extends QueryTest with SharedSQLContext with BeforeAndAfterAll {
27-
import testImplicits._
28-
24+
class SQLConfSuite extends QueryTest with SharedSQLContext {
2925
private val testKey = "test.key.0"
3026
private val testVal = "test.val.0"
3127

32-
override def beforeAll() {
33-
super.beforeAll()
34-
sql("DROP TABLE IF EXISTS testData")
35-
spark
36-
.range(10)
37-
.select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd)
38-
.write
39-
.saveAsTable("testData")
40-
}
41-
42-
override def afterAll(): Unit = {
43-
try {
44-
sql("DROP TABLE IF EXISTS testData")
45-
} finally {
46-
super.afterAll()
47-
}
48-
}
49-
5028
test("propagate from spark conf") {
5129
// We create a new context here to avoid order dependence with other tests that might call
5230
// clear().
@@ -243,31 +221,40 @@ class SQLConfSuite extends QueryTest with SharedSQLContext with BeforeAndAfterAl
243221
}
244222

245223
test("MAX_CASES_BRANCHES") {
224+
import testImplicits._
225+
246226
val original = spark.conf.get(SQLConf.MAX_CASES_BRANCHES)
247227
try {
248-
val sql_one_branch_caseWhen = "SELECT CASE WHEN a = 1 THEN 1 END FROM testData"
249-
val sql_two_branch_caseWhen = "SELECT CASE WHEN a = 1 THEN 1 ELSE 0 END FROM testData"
250-
251-
spark.conf.set(SQLConf.MAX_CASES_BRANCHES.key, "0")
252-
assert(!sql(sql_one_branch_caseWhen)
253-
.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
254-
assert(!sql(sql_two_branch_caseWhen)
255-
.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
256-
257-
spark.conf.set(SQLConf.MAX_CASES_BRANCHES.key, "1")
258-
assert(sql(sql_one_branch_caseWhen)
259-
.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
260-
assert(!sql(sql_two_branch_caseWhen)
261-
.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
262-
263-
spark.conf.set(SQLConf.MAX_CASES_BRANCHES.key, "2")
264-
assert(sql(sql_one_branch_caseWhen)
265-
.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
266-
assert(sql(sql_two_branch_caseWhen)
267-
.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
228+
withTable("tab1") {
229+
spark
230+
.range(10)
231+
.select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd)
232+
.write
233+
.saveAsTable("tab1")
234+
235+
val sql_one_branch_caseWhen = "SELECT CASE WHEN a = 1 THEN 1 END FROM tab1"
236+
val sql_two_branch_caseWhen = "SELECT CASE WHEN a = 1 THEN 1 ELSE 0 END FROM tab1"
237+
238+
spark.conf.set(SQLConf.MAX_CASES_BRANCHES.key, "0")
239+
assert(!sql(sql_one_branch_caseWhen)
240+
.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
241+
assert(!sql(sql_two_branch_caseWhen)
242+
.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
243+
244+
spark.conf.set(SQLConf.MAX_CASES_BRANCHES.key, "1")
245+
assert(sql(sql_one_branch_caseWhen)
246+
.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
247+
assert(!sql(sql_two_branch_caseWhen)
248+
.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
249+
250+
spark.conf.set(SQLConf.MAX_CASES_BRANCHES.key, "2")
251+
assert(sql(sql_one_branch_caseWhen)
252+
.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
253+
assert(sql(sql_two_branch_caseWhen)
254+
.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
255+
}
268256
} finally {
269257
spark.conf.set(SQLConf.MAX_CASES_BRANCHES.key, s"$original")
270258
}
271259
}
272-
273260
}

0 commit comments

Comments
 (0)