@@ -19,14 +19,15 @@ package org.apache.spark.sql.execution
 
 import java.sql.{Timestamp, Date}
 
-import org.apache.spark.serializer.Serializer
-import org.apache.spark.{SparkEnv, SparkConf, ShuffleDependency, SparkContext}
+import org.scalatest.{FunSuite, BeforeAndAfterAll}
+
 import org.apache.spark.rdd.ShuffledRDD
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.ShuffleDependency
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.Row
-import org.scalatest.{FunSuite, BeforeAndAfterAll}
-
-import org.apache.spark.sql.{MyDenseVectorUDT, SQLContext, QueryTest}
+import org.apache.spark.sql.test.TestSQLContext._
+import org.apache.spark.sql.{MyDenseVectorUDT, QueryTest}
 
 class SparkSqlSerializer2DataTypeSuite extends FunSuite {
   // Make sure that we will not use serializer2 for unsupported data types.
@@ -67,18 +68,17 @@ class SparkSqlSerializer2DataTypeSuite extends FunSuite {
 }
 
 abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll {
-
-  @transient var sparkContext: SparkContext = _
-  @transient var sqlContext: SQLContext = _
-  // We may have an existing SparkEnv (e.g. the one used by TestSQLContext).
-  @transient val existingSparkEnv = SparkEnv.get
   var allColumns: String = _
   val serializerClass: Class[Serializer] =
     classOf[SparkSqlSerializer2].asInstanceOf[Class[Serializer]]
+  var numShufflePartitions: Int = _
+  var useSerializer2: Boolean = _
 
   override def beforeAll(): Unit = {
-    sqlContext.sql("set spark.sql.shuffle.partitions=5")
-    sqlContext.sql("set spark.sql.useSerializer2=true")
+    numShufflePartitions = conf.numShufflePartitions
+    useSerializer2 = conf.useSqlSerializer2
+
+    sql("set spark.sql.useSerializer2=true")
 
     val supportedTypes =
       Seq(StringType, BinaryType, NullType, BooleanType,
@@ -112,18 +112,15 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
         new Timestamp(i))
     }
 
-    sqlContext.createDataFrame(rdd, schema).registerTempTable("shuffle")
+    createDataFrame(rdd, schema).registerTempTable("shuffle")
 
     super.beforeAll()
   }
 
   override def afterAll(): Unit = {
-    sqlContext.dropTempTable("shuffle")
-    sparkContext.stop()
-    sqlContext = null
-    sparkContext = null
-    // Set the existing SparkEnv back.
-    SparkEnv.set(existingSparkEnv)
+    dropTempTable("shuffle")
+    sql(s"set spark.sql.shuffle.partitions=$numShufflePartitions")
+    sql(s"set spark.sql.useSerializer2=$useSerializer2")
     super.afterAll()
   }
 
@@ -144,64 +141,40 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
   }
 
   test("key schema and value schema are not nulls") {
-    val df = sqlContext.sql(s"SELECT DISTINCT ${allColumns} FROM shuffle")
+    val df = sql(s"SELECT DISTINCT ${allColumns} FROM shuffle")
     checkSerializer(df.queryExecution.executedPlan, serializerClass)
     checkAnswer(
       df,
-      sqlContext.table("shuffle").collect())
+      table("shuffle").collect())
   }
 
   test("value schema is null") {
-    val df = sqlContext.sql(s"SELECT col0 FROM shuffle ORDER BY col0")
+    val df = sql(s"SELECT col0 FROM shuffle ORDER BY col0")
     checkSerializer(df.queryExecution.executedPlan, serializerClass)
     assert(
       df.map(r => r.getString(0)).collect().toSeq ===
-      sqlContext.table("shuffle").select("col0").map(r => r.getString(0)).collect().sorted.toSeq)
+      table("shuffle").select("col0").map(r => r.getString(0)).collect().sorted.toSeq)
+  }
+}
+
+/** Tests SparkSqlSerializer2 with sort based shuffle without sort merge. */
+class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite {
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    // Sort merge will not be triggered.
+    sql("set spark.sql.shuffle.partitions = 200")
   }
 
   test("key schema is null") {
     val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",")
-    val df = sqlContext.sql(s"SELECT $aggregations FROM shuffle")
+    val df = sql(s"SELECT $aggregations FROM shuffle")
     checkSerializer(df.queryExecution.executedPlan, serializerClass)
     checkAnswer(
       df,
       Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000))
   }
 }
 
-/** Tests SparkSqlSerializer2 with hash based shuffle. */
-class SparkSqlSerializer2HashShuffleSuite extends SparkSqlSerializer2Suite {
-  override def beforeAll(): Unit = {
-    val sparkConf =
-      new SparkConf()
-        .set("spark.driver.allowMultipleContexts", "true")
-        .set("spark.sql.testkey", "true")
-        .set("spark.shuffle.manager", "hash")
-
-    sparkContext = new SparkContext("local[2]", "Serializer2SQLContext", sparkConf)
-    sqlContext = new SQLContext(sparkContext)
-    super.beforeAll()
-  }
-}
-
-/** Tests SparkSqlSerializer2 with sort based shuffle without sort merge. */
-class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite {
-  override def beforeAll(): Unit = {
-    // Since spark.sql.shuffle.partition is 5, we will not do sort merge when
-    // spark.shuffle.sort.bypassMergeThreshold is also 5.
-    val sparkConf =
-      new SparkConf()
-        .set("spark.driver.allowMultipleContexts", "true")
-        .set("spark.sql.testkey", "true")
-        .set("spark.shuffle.manager", "sort")
-        .set("spark.shuffle.sort.bypassMergeThreshold", "5")
-
-    sparkContext = new SparkContext("local[2]", "Serializer2SQLContext", sparkConf)
-    sqlContext = new SQLContext(sparkContext)
-    super.beforeAll()
-  }
-}
-
 /** For now, we will use SparkSqlSerializer for sort based shuffle with sort merge. */
 class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite {
 
@@ -210,15 +183,8 @@ class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite
     classOf[SparkSqlSerializer].asInstanceOf[Class[Serializer]]
 
   override def beforeAll(): Unit = {
-    val sparkConf =
-      new SparkConf()
-        .set("spark.driver.allowMultipleContexts", "true")
-        .set("spark.sql.testkey", "true")
-        .set("spark.shuffle.manager", "sort")
-        .set("spark.shuffle.sort.bypassMergeThreshold", "0") // Always do sort merge.
-
-    sparkContext = new SparkContext("local[2]", "Serializer2SQLContext", sparkConf)
-    sqlContext = new SQLContext(sparkContext)
     super.beforeAll()
+    // To trigger the sort merge.
+    sql("set spark.sql.shuffle.partitions = 201")
  }
}
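
A note on the 200 vs. 201 partition counts chosen above: sort-based shuffle skips its merge-sort path when a shuffle has no map-side combine and no more partitions than `spark.shuffle.sort.bypassMergeThreshold`, which defaults to 200. That is why 200 partitions keep `SparkSqlSerializer2SortShuffleSuite` on the no-sort-merge path, while 201 partitions push `SparkSqlSerializer2SortMergeShuffleSuite` into sort merge. The sketch below is a simplified restatement of that decision; the object, helper name, and signature are illustrative, not Spark's internal API.

```scala
// Simplified, illustrative restatement of the bypass decision exercised by
// the two suites above; this is not Spark's actual internal code.
object BypassMergeSortSketch {
  def shouldBypassMergeSort(
      bypassMergeThreshold: Int, // spark.shuffle.sort.bypassMergeThreshold, default 200
      numPartitions: Int,        // spark.sql.shuffle.partitions for the query
      hasMapSideCombine: Boolean): Boolean =
    !hasMapSideCombine && numPartitions <= bypassMergeThreshold

  def main(args: Array[String]): Unit = {
    // SparkSqlSerializer2SortShuffleSuite sets 200 partitions: the bypass
    // path is taken, no sort merge happens, and SparkSqlSerializer2 is used.
    assert(shouldBypassMergeSort(200, 200, hasMapSideCombine = false))
    // SparkSqlSerializer2SortMergeShuffleSuite sets 201 partitions: the
    // threshold is exceeded, sort merge runs, and that suite expects the
    // generic SparkSqlSerializer instead.
    assert(!shouldBypassMergeSort(200, 201, hasMapSideCombine = false))
  }
}
```

This is also why the base suite snapshots `conf.numShufflePartitions` in `beforeAll` and restores it in `afterAll`: both subclasses mutate `spark.sql.shuffle.partitions` on the shared `TestSQLContext`, so leaking either value would put whichever suite runs next on the wrong shuffle path.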