
Commit 625e51f

Fix NPE

1 parent 6e58919 · commit 625e51f

2 files changed: +13 additions, −3 deletions

external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala

Lines changed: 7 additions & 1 deletion

```diff
@@ -31,11 +31,12 @@ import org.scalatest.concurrent.Eventually._
 import org.scalatest.concurrent.PatienceConfiguration.Timeout
 import org.scalatest.time.SpanSugar._
 
+import org.apache.spark.SparkContext
 import org.apache.spark.sql.ForeachWriter
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.functions.{count, window}
 import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest}
-import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession}
 
 abstract class KafkaSourceTest extends StreamTest with SharedSQLContext {
 
@@ -810,6 +811,11 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared
 
   private def newTopic(): String = s"failOnDataLoss-${topicId.getAndIncrement()}"
 
+  override def createSparkSession(): TestSparkSession = {
+    new TestSparkSession(new SparkContext("local[2,3]", "test-sql-context",
+      sparkConf.set("spark.sql.testkey", "true")))
+  }
+
   override def beforeAll(): Unit = {
     super.beforeAll()
     testUtils = new KafkaTestUtils {
```
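A note on the master URL used above: in Spark's local mode, `local[2,3]` means 2 worker threads with `maxFailures` set to 3, so a task that throws (for example, the transient NPE this commit works around) is retried up to 3 times instead of failing the job on its first attempt. A minimal, self-contained sketch of that retry behavior (the object name, app name, and the deliberate first-attempt failure are illustrative, not from this commit):

```scala
import org.apache.spark.{SparkConf, SparkContext, TaskContext}

// Illustrative sketch (not part of this commit): with master "local[2,3]",
// local mode runs 2 worker threads and allows up to 3 attempts per task,
// so a failure on the first attempt is retried instead of killing the job.
object LocalRetrySketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2,3]").setAppName("retry-sketch"))
    try {
      val sum = sc.parallelize(1 to 10, 2).map { i =>
        // Fail the first attempt of every task; the retried attempt succeeds.
        if (TaskContext.get().attemptNumber() == 0) {
          throw new NullPointerException("transient failure")
        }
        i
      }.sum()
      println(sum) // 55.0 once the retried attempts succeed
    } finally {
      sc.stop()
    }
  }
}
```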

sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala

Lines changed: 6 additions & 2 deletions

```diff
@@ -48,14 +48,18 @@ trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach {
    */
   protected implicit def sqlContext: SQLContext = _spark.sqlContext
 
+  protected def createSparkSession: TestSparkSession = {
+    new TestSparkSession(
+      sparkConf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName))
+  }
+
   /**
    * Initialize the [[TestSparkSession]].
    */
   protected override def beforeAll(): Unit = {
     SparkSession.sqlListener.set(null)
     if (_spark == null) {
-      _spark = new TestSparkSession(
-        sparkConf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName))
+      _spark = createSparkSession
     }
     // Ensure we have initialized the context before calling parent code
     super.beforeAll()
```
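This second change is the refactor that makes the Kafka fix possible: session construction moves out of `beforeAll` into an overridable `createSparkSession` hook, so `beforeAll` keeps the lifecycle (listener reset, lazy initialization) while subclasses substitute their own `TestSparkSession`. A hypothetical suite using the hook might look like this (the class name, test body, and config value are illustrative, not from this commit):

```scala
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession}

// Hypothetical example (not from this commit): a suite that overrides the
// new createSparkSession hook to customize the shared session's SparkConf.
class SinglePartitionShuffleSuite extends QueryTest with SharedSQLContext {
  override protected def createSparkSession: TestSparkSession = {
    // Any SparkConf entry could be set here; this one is illustrative.
    new TestSparkSession(sparkConf.set("spark.sql.shuffle.partitions", "1"))
  }

  test("shuffles collapse to a single partition") {
    val grouped = spark.range(100).groupBy("id").count()
    assert(grouped.rdd.getNumPartitions === 1)
  }
}
```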
