Skip to content

Commit cd8269b

Browse files
committed
Refactor test to use SQLTestUtils
1 parent 2963857 commit cd8269b

File tree

1 file changed

+44
-37
lines changed

1 file changed

+44
-37
lines changed

sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

Lines changed: 44 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,18 @@ import org.apache.spark.sql.catalyst.plans._
2323
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
2424
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin}
2525
import org.apache.spark.sql.functions._
26+
import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext}
2627
import org.apache.spark.sql.test.TestSQLContext._
2728
import org.apache.spark.sql.test.TestSQLContext.implicits._
2829
import org.apache.spark.sql.test.TestSQLContext.planner._
2930
import org.apache.spark.sql.types._
30-
import org.apache.spark.sql.{Row, SQLConf, execution}
31+
import org.apache.spark.sql.{SQLContext, Row, SQLConf, execution}
3132

3233

33-
class PlannerSuite extends SparkFunSuite {
34+
class PlannerSuite extends SparkFunSuite with SQLTestUtils {
35+
36+
override def sqlContext: SQLContext = TestSQLContext
37+
3438
private def testPartialAggregationPlan(query: LogicalPlan): Unit = {
3539
val plannedOption = HashAggregation(query).headOption.orElse(Aggregation(query).headOption)
3640
val planned =
@@ -159,40 +163,43 @@ class PlannerSuite extends SparkFunSuite {
159163
}
160164

161165
test("PartitioningCollection") {
162-
// First, we disable broadcast join.
163-
val origThreshold = conf.autoBroadcastJoinThreshold
164-
setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 0)
165-
166-
testData.registerTempTable("normal")
167-
testData.limit(10).registerTempTable("small")
168-
testData.limit(3).registerTempTable("tiny")
169-
var numExchanges = sql(
170-
"""
171-
|SELECT *
172-
|FROM
173-
| normal JOIN small ON (normal.key = small.key)
174-
| JOIN tiny ON (small.key = tiny.key)
175-
""".stripMargin).queryExecution.executedPlan.collect {
176-
case exchange: Exchange => exchange
177-
}.length
178-
179-
assert(numExchanges === 3)
180-
181-
numExchanges = sql(
182-
"""
183-
|SELECT *
184-
|FROM
185-
| normal JOIN small ON (normal.key = small.key)
186-
| JOIN tiny ON (normal.key = tiny.key)
187-
""".stripMargin).queryExecution.executedPlan.collect {
188-
case exchange: Exchange => exchange
189-
}.length
190-
191-
assert(numExchanges === 3)
192-
193-
setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold)
194-
dropTempTable("normal")
195-
dropTempTable("small")
196-
dropTempTable("tiny")
166+
withTempTable("normal", "small", "tiny") {
167+
testData.registerTempTable("normal")
168+
testData.limit(10).registerTempTable("small")
169+
testData.limit(3).registerTempTable("tiny")
170+
171+
// Disable broadcast join
172+
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
173+
{
174+
val numExchanges = sql(
175+
"""
176+
|SELECT *
177+
|FROM
178+
| normal JOIN small ON (normal.key = small.key)
179+
| JOIN tiny ON (small.key = tiny.key)
180+
""".stripMargin
181+
).queryExecution.executedPlan.collect {
182+
case exchange: Exchange => exchange
183+
}.length
184+
assert(numExchanges === 3)
185+
}
186+
187+
{
188+
// This second query joins on different keys:
189+
val numExchanges = sql(
190+
"""
191+
|SELECT *
192+
|FROM
193+
| normal JOIN small ON (normal.key = small.key)
194+
| JOIN tiny ON (normal.key = tiny.key)
195+
""".stripMargin
196+
).queryExecution.executedPlan.collect {
197+
case exchange: Exchange => exchange
198+
}.length
199+
assert(numExchanges === 3)
200+
}
201+
202+
}
203+
}
197204
}
198205
}

0 commit comments

Comments
 (0)