
Commit 29b1f6b

[SPARK-21256][SQL] Add withSQLConf to Catalyst Test
### What changes were proposed in this pull request?

SQLConf is moved to Catalyst. We are adding more and more test cases for verifying the conf-specific behaviors. It is nice to add a helper function to simplify the test cases.

### How was this patch tested?

N/A

Author: gatorsmile <[email protected]>

Closes #18469 from gatorsmile/withSQLConf.
1 parent d492cc5 commit 29b1f6b
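
For context, this is the shape of the change repeated across the suites below: a hand-rolled try/finally around SQLConf mutation is replaced by the new helper. A minimal sketch (the config key is taken from the diffs below; the surrounding test body is a placeholder):

    // Before: each test saved, set, and manually restored the conf.
    try {
      SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false)
      // ... assertions that depend on the setting ...
    } finally {
      SQLConf.get.unsetConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED)
    }

    // After: withSQLConf sets the keys, runs the body, and restores the
    // previous values even when the body throws.
    withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") {
      // ... assertions that depend on the setting ...
    }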

9 files changed: +64 −63 lines changed

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala

Lines changed: 1 addition & 4 deletions
@@ -206,13 +206,10 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
   }

   test("No inferred filter when constraint propagation is disabled") {
-    try {
-      SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false)
+    withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") {
       val originalQuery = testRelation.where('a === 1 && 'a === 'b).analyze
       val optimized = Optimize.execute(originalQuery)
       comparePlans(optimized, originalQuery)
-    } finally {
-      SQLConf.get.unsetConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED)
     }
   }
 }

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala

Lines changed: 1 addition & 5 deletions
@@ -234,9 +234,7 @@ class OuterJoinEliminationSuite extends PlanTest {
   }

   test("no outer join elimination if constraint propagation is disabled") {
-    try {
-      SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false)
-
+    withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") {
       val x = testRelation.subquery('x)
       val y = testRelation1.subquery('y)

@@ -251,8 +249,6 @@ class OuterJoinEliminationSuite extends PlanTest {
       val optimized = Optimize.execute(originalQuery.analyze)

       comparePlans(optimized, originalQuery.analyze)
-    } finally {
-      SQLConf.get.unsetConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED)
     }
   }
 }

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala

Lines changed: 1 addition & 5 deletions
@@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.internal.SQLConf.CONSTRAINT_PROPAGATION_ENABLED

 class PruneFiltersSuite extends PlanTest {

@@ -149,8 +148,7 @@ class PruneFiltersSuite extends PlanTest {
         ("tr1.a".attr > 10 || "tr1.c".attr < 10) &&
         'd.attr < 100)

-    SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false)
-    try {
+    withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") {
       val optimized = Optimize.execute(queryWithUselessFilter.analyze)
       // When constraint propagation is disabled, the useless filter won't be pruned.
       // It gets pushed down. Because the rule `CombineFilters` runs only once, there are redundant
@@ -160,8 +158,6 @@ class PruneFiltersSuite extends PlanTest {
          .join(tr2.where('d.attr < 100).where('d.attr < 100),
            Inner, Some("tr1.a".attr === "tr2.a".attr)).analyze
       comparePlans(optimized, correctAnswer)
-    } finally {
-      SQLConf.get.unsetConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED)
     }
   }
 }

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala

Lines changed: 12 additions & 12 deletions
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, LongType, StringType}

-class ConstraintPropagationSuite extends SparkFunSuite {
+class ConstraintPropagationSuite extends SparkFunSuite with PlanTest {

   private def resolveColumn(tr: LocalRelation, columnName: String): Expression =
     resolveColumn(tr.analyze, columnName)
@@ -400,26 +400,26 @@ class ConstraintPropagationSuite extends SparkFunSuite {
   }

   test("enable/disable constraint propagation") {
-    try {
-      val tr = LocalRelation('a.int, 'b.string, 'c.int)
-      val filterRelation = tr.where('a.attr > 10)
+    val tr = LocalRelation('a.int, 'b.string, 'c.int)
+    val filterRelation = tr.where('a.attr > 10)

-      SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, true)
+    withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "true") {
       assert(filterRelation.analyze.constraints.nonEmpty)
+    }

-      SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false)
+    withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") {
       assert(filterRelation.analyze.constraints.isEmpty)
+    }

-      val aliasedRelation = tr.where('c.attr > 10 && 'a.attr < 5)
-        .groupBy('a, 'c, 'b)('a, 'c.as("c1"), count('a).as("a3")).select('c1, 'a, 'a3)
+    val aliasedRelation = tr.where('c.attr > 10 && 'a.attr < 5)
+      .groupBy('a, 'c, 'b)('a, 'c.as("c1"), count('a).as("a3")).select('c1, 'a, 'a3)

-      SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, true)
+    withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "true") {
       assert(aliasedRelation.analyze.constraints.nonEmpty)
+    }

-      SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false)
+    withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") {
       assert(aliasedRelation.analyze.constraints.isEmpty)
-    } finally {
-      SQLConf.get.unsetConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED)
     }
   }
 }

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala

Lines changed: 31 additions & 1 deletion
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.plans

 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
@@ -28,8 +29,9 @@ import org.apache.spark.sql.internal.SQLConf
 /**
  * Provides helper methods for comparing plans.
  */
-abstract class PlanTest extends SparkFunSuite with PredicateHelper {
+trait PlanTest extends SparkFunSuite with PredicateHelper {

+  // TODO(gatorsmile): remove this from PlanTest and all the analyzer/optimizer rules
   protected val conf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true)

   /**
@@ -142,4 +144,32 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper {
       plan1 == plan2
     }
   }
+
+  /**
+   * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL
+   * configurations.
+   */
+  protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
+    val conf = SQLConf.get
+    val (keys, values) = pairs.unzip
+    val currentValues = keys.map { key =>
+      if (conf.contains(key)) {
+        Some(conf.getConfString(key))
+      } else {
+        None
+      }
+    }
+    (keys, values).zipped.foreach { (k, v) =>
+      if (SQLConf.staticConfKeys.contains(k)) {
+        throw new AnalysisException(s"Cannot modify the value of a static config: $k")
+      }
+      conf.setConfString(k, v)
+    }
+    try f finally {
+      keys.zip(currentValues).foreach {
+        case (key, Some(value)) => conf.setConfString(key, value)
+        case (key, None) => conf.unsetConf(key)
+      }
+    }
+  }
 }
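
The helper above snapshots each key's current value, rejects static configs with an AnalysisException, and restores or unsets every key in a finally block, so a failing assertion cannot leak settings into later tests. A hypothetical Catalyst suite would use it like this (the suite name and the second key are illustrative, not part of this commit):

    import org.apache.spark.sql.catalyst.plans.PlanTest
    import org.apache.spark.sql.internal.SQLConf

    class MyRuleSuite extends PlanTest {
      test("rule behavior under specific configs") {
        withSQLConf(
          SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false",
          SQLConf.CASE_SENSITIVE.key -> "true") {
          // Both settings are visible through SQLConf.get inside this block,
          // then rolled back to their prior values (or unset) on exit.
          assert(!SQLConf.get.getConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED))
        }
      }
    }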

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala

Lines changed: 3 additions & 6 deletions
@@ -19,12 +19,13 @@ package org.apache.spark.sql.catalyst.statsEstimation

 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, Literal}
 import org.apache.spark.sql.catalyst.expressions.aggregate.Count
+import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
 import org.apache.spark.sql.internal.SQLConf


-class AggregateEstimationSuite extends StatsEstimationTestBase {
+class AggregateEstimationSuite extends StatsEstimationTestBase with PlanTest {

   /** Columns for testing */
   private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq(
@@ -100,9 +101,7 @@ class AggregateEstimationSuite extends StatsEstimationTestBase {
       size = Some(4 * (8 + 4)),
       attributeStats = AttributeMap(Seq("key12").map(nameToColInfo)))

-    val originalValue = SQLConf.get.getConf(SQLConf.CBO_ENABLED)
-    try {
-      SQLConf.get.setConf(SQLConf.CBO_ENABLED, false)
+    withSQLConf(SQLConf.CBO_ENABLED.key -> "false") {
       val noGroupAgg = Aggregate(groupingExpressions = Nil,
         aggregateExpressions = Seq(Alias(Count(Literal(1)), "cnt")()), child)
       assert(noGroupAgg.stats ==
@@ -114,8 +113,6 @@ class AggregateEstimationSuite extends StatsEstimationTestBase {
       assert(hasGroupAgg.stats ==
         // From UnaryNode.computeStats, childSize * outputRowSize / childRowSize
         Statistics(sizeInBytes = 48 * (8 + 4 + 8) / (8 + 4)))
-    } finally {
-      SQLConf.get.setConf(SQLConf.CBO_ENABLED, originalValue)
     }
   }

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala

Lines changed: 5 additions & 7 deletions
@@ -18,12 +18,13 @@
 package org.apache.spark.sql.catalyst.statsEstimation

 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Literal}
+import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.IntegerType


-class BasicStatsEstimationSuite extends StatsEstimationTestBase {
+class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase {
   val attribute = attr("key")
   val colStat = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
     nullCount = 0, avgLen = 4, maxLen = 4)
@@ -82,18 +83,15 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase {
       plan: LogicalPlan,
       expectedStatsCboOn: Statistics,
       expectedStatsCboOff: Statistics): Unit = {
-    val originalValue = SQLConf.get.getConf(SQLConf.CBO_ENABLED)
-    try {
+    withSQLConf(SQLConf.CBO_ENABLED.key -> "true") {
       // Invalidate statistics
       plan.invalidateStatsCache()
-      SQLConf.get.setConf(SQLConf.CBO_ENABLED, true)
       assert(plan.stats == expectedStatsCboOn)
+    }

+    withSQLConf(SQLConf.CBO_ENABLED.key -> "false") {
       plan.invalidateStatsCache()
-      SQLConf.get.setConf(SQLConf.CBO_ENABLED, false)
       assert(plan.stats == expectedStatsCboOff)
-    } finally {
-      SQLConf.get.setConf(SQLConf.CBO_ENABLED, originalValue)
     }
   }


sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala

Lines changed: 3 additions & 0 deletions
@@ -18,6 +18,7 @@
 package org.apache.spark.sql

 import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.sql.internal.SQLConf

 /**
  * Test cases for the builder pattern of [[SparkSession]].
@@ -67,6 +68,8 @@ class SparkSessionBuilderSuite extends SparkFunSuite {
     assert(activeSession != defaultSession)
     assert(session == activeSession)
     assert(session.conf.get("spark-config2") == "a")
+    assert(session.sessionState.conf == SQLConf.get)
+    assert(SQLConf.get.getConfString("spark-config2") == "a")
     SparkSession.clearActiveSession()

     assert(SparkSession.builder().getOrCreate() == defaultSession)
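
The two added assertions pin down the contract the Catalyst-side helper depends on: while a session is active on the current thread, SQLConf.get resolves to that session's conf. A sketch of the relationship being asserted ("spark-config2" is the custom key this test sets earlier via the builder; the explicit setActiveSession call is only for illustration, the builder already activates the session):

    SparkSession.setActiveSession(session)
    // The static accessor and the session now expose the same conf ...
    assert(session.sessionState.conf == SQLConf.get)
    // ... so session-level settings are readable through SQLConf.get.
    assert(SQLConf.get.getConfString("spark-config2") == "a")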

sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala

Lines changed: 7 additions & 23 deletions
@@ -35,9 +35,11 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
 import org.apache.spark.sql.catalyst.catalog.SessionCatalog.DEFAULT_DATABASE
 import org.apache.spark.sql.catalyst.FunctionIdentifier
+import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.execution.FilterExec
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.{UninterruptibleThread, Utils}

 /**
@@ -53,7 +55,8 @@ import org.apache.spark.util.{UninterruptibleThread, Utils}
 private[sql] trait SQLTestUtils
   extends SparkFunSuite with Eventually
   with BeforeAndAfterAll
-  with SQLTestData { self =>
+  with SQLTestData
+  with PlanTest { self =>

   protected def sparkContext = spark.sparkContext

@@ -89,28 +92,9 @@ private[sql] trait SQLTestUtils
     }
   }

-  /**
-   * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL
-   * configurations.
-   *
-   * @todo Probably this method should be moved to a more general place
-   */
-  protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
-    val (keys, values) = pairs.unzip
-    val currentValues = keys.map { key =>
-      if (spark.conf.contains(key)) {
-        Some(spark.conf.get(key))
-      } else {
-        None
-      }
-    }
-    (keys, values).zipped.foreach(spark.conf.set)
-    try f finally {
-      keys.zip(currentValues).foreach {
-        case (key, Some(value)) => spark.conf.set(key, value)
-        case (key, None) => spark.conf.unset(key)
-      }
-    }
+  protected override def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
+    SparkSession.setActiveSession(spark)
+    super.withSQLConf(pairs: _*)(f)
   }

 /**
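
With the shared implementation now living in PlanTest, the sql/core override only needs to make sure SQLConf.get points at the suite's SparkSession before delegating. A sketch of why the pinning matters (assumed fallback behavior, consistent with the SparkSessionBuilderSuite assertions above):

    // If no session were active on this thread, SQLConf.get could resolve to a
    // fallback conf, and the helper would mutate settings that the suite's
    // `spark` session never observes. Pinning the active session avoids that:
    SparkSession.setActiveSession(spark)
    withSQLConf(SQLConf.CBO_ENABLED.key -> "false") {
      // queries issued through `spark` here run with CBO disabled
    }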
