diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 135c7576e529..c8bb15bd8cab 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -62,6 +62,10 @@ mockLinesComplexType <-
 complexTypeJsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp")
 writeLines(mockLinesComplexType, complexTypeJsonPath)
 
+test_that("calling sparkRSQL.init returns existing SQL context", {
+  expect_equal(sparkRSQL.init(sc), sqlContext)
+})
+
 test_that("infer types and check types", {
   expect_equal(infer_type(1L), "integer")
   expect_equal(infer_type(1.0), "double")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index b3f134614c6b..67da7b808bf5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -32,7 +32,7 @@ private[r] object SQLUtils {
   SerDe.registerSqlSerDe((readSqlObject, writeSqlObject))
 
   def createSQLContext(jsc: JavaSparkContext): SQLContext = {
-    new SQLContext(jsc)
+    SQLContext.getOrCreate(jsc.sc)
  }
 
   def getJavaSparkContext(sqlCtx: SQLContext): JavaSparkContext = {
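
Note on the change (not part of the patch): SQLContext.getOrCreate returns the singleton SQLContext associated with the given SparkContext, creating one only on the first call, so repeated calls to createSQLContext from SparkR now return the same instance instead of constructing a new context each time. A minimal standalone sketch of that behavior, assuming a Spark 1.x deployment where SQLContext.getOrCreate is available; the object and app names are illustrative only:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object GetOrCreateDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("getOrCreate-demo").setMaster("local[*]"))

    // First call creates the singleton SQLContext for this SparkContext...
    val ctx1 = SQLContext.getOrCreate(sc)
    // ...and subsequent calls return the same instance rather than a new one.
    val ctx2 = SQLContext.getOrCreate(sc)

    // Reference equality: the existing context is reused, which is what the
    // new R test asserts via expect_equal(sparkRSQL.init(sc), sqlContext).
    assert(ctx1 eq ctx2)

    sc.stop()
  }
}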