@@ -23,14 +23,18 @@ import org.apache.spark.sql.catalyst.plans._
2323import org .apache .spark .sql .catalyst .plans .logical .LogicalPlan
2424import org .apache .spark .sql .execution .joins .{BroadcastHashJoin , ShuffledHashJoin }
2525import org .apache .spark .sql .functions ._
26+ import org .apache .spark .sql .test .{SQLTestUtils , TestSQLContext }
2627import org .apache .spark .sql .test .TestSQLContext ._
2728import org .apache .spark .sql .test .TestSQLContext .implicits ._
2829import org .apache .spark .sql .test .TestSQLContext .planner ._
2930import org .apache .spark .sql .types ._
30- import org .apache .spark .sql .{Row , SQLConf , execution }
31+ import org .apache .spark .sql .{SQLContext , Row , SQLConf , execution }
3132
3233
33- class PlannerSuite extends SparkFunSuite {
34+ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
35+
36+ override def sqlContext : SQLContext = TestSQLContext
37+
3438 private def testPartialAggregationPlan (query : LogicalPlan ): Unit = {
3539 val plannedOption = HashAggregation (query).headOption.orElse(Aggregation (query).headOption)
3640 val planned =
@@ -159,40 +163,43 @@ class PlannerSuite extends SparkFunSuite {
159163 }
160164
161165 test(" PartitioningCollection" ) {
162- // First, we disable broadcast join.
163- val origThreshold = conf.autoBroadcastJoinThreshold
164- setConf(SQLConf .AUTO_BROADCASTJOIN_THRESHOLD , 0 )
165-
166- testData.registerTempTable(" normal" )
167- testData.limit(10 ).registerTempTable(" small" )
168- testData.limit(3 ).registerTempTable(" tiny" )
169- var numExchanges = sql(
170- """
171- |SELECT *
172- |FROM
173- | normal JOIN small ON (normal.key = small.key)
174- | JOIN tiny ON (small.key = tiny.key)
175- """ .stripMargin).queryExecution.executedPlan.collect {
176- case exchange : Exchange => exchange
177- }.length
178-
179- assert(numExchanges === 3 )
180-
181- numExchanges = sql(
182- """
183- |SELECT *
184- |FROM
185- | normal JOIN small ON (normal.key = small.key)
186- | JOIN tiny ON (normal.key = tiny.key)
187- """ .stripMargin).queryExecution.executedPlan.collect {
188- case exchange : Exchange => exchange
189- }.length
190-
191- assert(numExchanges === 3 )
192-
193- setConf(SQLConf .AUTO_BROADCASTJOIN_THRESHOLD , origThreshold)
194- dropTempTable(" normal" )
195- dropTempTable(" small" )
196- dropTempTable(" tiny" )
166+ withTempTable(" normal" , " small" , " tiny" ) {
167+ testData.registerTempTable(" normal" )
168+ testData.limit(10 ).registerTempTable(" small" )
169+ testData.limit(3 ).registerTempTable(" tiny" )
170+
171+ // Disable broadcast join
172+ withSQLConf(SQLConf .AUTO_BROADCASTJOIN_THRESHOLD .key -> " -1" ) {
173+ {
174+ val numExchanges = sql(
175+ """
176+ |SELECT *
177+ |FROM
178+ | normal JOIN small ON (normal.key = small.key)
179+ | JOIN tiny ON (small.key = tiny.key)
180+ """ .stripMargin
181+ ).queryExecution.executedPlan.collect {
182+ case exchange : Exchange => exchange
183+ }.length
184+ assert(numExchanges === 3 )
185+ }
186+
187+ {
188+ // This second query joins on different keys:
189+ val numExchanges = sql(
190+ """
191+ |SELECT *
192+ |FROM
193+ | normal JOIN small ON (normal.key = small.key)
194+ | JOIN tiny ON (normal.key = tiny.key)
195+ """ .stripMargin
196+ ).queryExecution.executedPlan.collect {
197+ case exchange : Exchange => exchange
198+ }.length
199+ assert(numExchanges === 3 )
200+ }
201+
202+ }
203+ }
197204 }
198205}
0 commit comments