@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
20  20  import java.sql.{Timestamp, Date}
21  21
22  22  import org.apache.spark.serializer.Serializer
23      -import org.apache.spark.{SparkConf, ShuffleDependency, SparkContext}
    23  +import org.apache.spark.{SparkEnv, SparkConf, ShuffleDependency, SparkContext}
24  24  import org.apache.spark.rdd.ShuffledRDD
25  25  import org.apache.spark.sql.types._
26  26  import org.apache.spark.sql.Row
@@ -70,6 +70,8 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
70  70
71  71    @transient var sparkContext: SparkContext = _
72  72    @transient var sqlContext: SQLContext = _
    73  + // We may have an existing SparkEnv (e.g. the one used by TestSQLContext).
    74  + @transient val existingSparkEnv = SparkEnv.get
73  75    var allColumns: String = _
74  76    val serializerClass: Class[Serializer] =
75  77      classOf[SparkSqlSerializer2].asInstanceOf[Class[Serializer]]
@@ -118,6 +120,10 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
118 120   override def afterAll(): Unit = {
119 121     sqlContext.dropTempTable("shuffle")
120 122     sparkContext.stop()
    123 +   sqlContext = null
    124 +   sparkContext = null
    125 +   // Set the existing SparkEnv back.
    126 +   SparkEnv.set(existingSparkEnv)
121 127     super.afterAll()
122 128   }
123129
@@ -168,6 +174,7 @@ class SparkSqlSerializer2HashShuffleSuite extends SparkSqlSerializer2Suite {
168 174   override def beforeAll(): Unit = {
169 175     val sparkConf =
170 176       new SparkConf()
    177 +       .set("spark.driver.allowMultipleContexts", "true")
171 178         .set("spark.sql.testkey", "true")
172 179         .set("spark.shuffle.manager", "hash")
173 180
@@ -184,6 +191,7 @@ class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite {
184 191     // spark.shuffle.sort.bypassMergeThreshold is also 5.
185 192     val sparkConf =
186 193       new SparkConf()
    194 +       .set("spark.driver.allowMultipleContexts", "true")
187 195         .set("spark.sql.testkey", "true")
188 196         .set("spark.shuffle.manager", "sort")
189 197         .set("spark.shuffle.sort.bypassMergeThreshold", "5")
@@ -204,6 +212,7 @@ class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite
204 212   override def beforeAll(): Unit = {
205 213     val sparkConf =
206 214       new SparkConf()
    215 +       .set("spark.driver.allowMultipleContexts", "true")
207 216         .set("spark.sql.testkey", "true")
208 217         .set("spark.shuffle.manager", "sort")
209 218         .set("spark.shuffle.sort.bypassMergeThreshold", "0") // Always do sort merge.
0 commit comments