From 164629ac7ce9d75434746818673b2999cf204f9c Mon Sep 17 00:00:00 2001 From: felixcheung Date: Mon, 30 May 2016 12:26:30 -0700 Subject: [PATCH 01/10] [WIP] SparkSession in R --- R/pkg/R/sparkR.R | 69 ++++++++++++++++--- .../org/apache/spark/sql/api/r/SQLUtils.scala | 11 ++- 2 files changed, 70 insertions(+), 10 deletions(-) diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 04a8b1e1f3952..e8b03923af395 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -114,6 +114,19 @@ sparkR.init <- function( sparkExecutorEnv = list(), sparkJars = "", sparkPackages = "") { + .Deprecated("sparkR.session.getOrCreate") + sparkR.sparkContext(master, appName, sparkHome, sparkEnvir, sparkExecutorEnv, sparkJars, + sparkPackages) +} + +sparkR.sparkContext <- function( + master = "", + appName = "SparkR", + sparkHome = Sys.getenv("SPARK_HOME"), + sparkEnvir = list(), + sparkExecutorEnv = list(), + sparkJars = "", + sparkPackages = "") { if (exists(".sparkRjsc", envir = .sparkREnv)) { cat(paste("Re-using existing Spark Context.", @@ -239,21 +252,22 @@ sparkR.init <- function( #'} sparkRSQL.init <- function(jsc = NULL) { - if (exists(".sparkRSQLsc", envir = .sparkREnv)) { - return(get(".sparkRSQLsc", envir = .sparkREnv)) + .Deprecated("sparkR.session.getOrCreate") + + if (exists(".sparkRsession", envir = .sparkREnv)) { + return(get(".sparkRsession", envir = .sparkREnv)) } # If jsc is NULL, create a Spark Context sc <- if (is.null(jsc)) { - sparkR.init() + sparkR.sparkContext() } else { jsc } - sqlContext <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "createSQLContext", sc) - assign(".sparkRSQLsc", sqlContext, envir = .sparkREnv) + assign(".sparkRsession", sqlContext, envir = .sparkREnv) sqlContext } @@ -270,13 +284,15 @@ sparkRSQL.init <- function(jsc = NULL) { #'} sparkRHive.init <- function(jsc = NULL) { - if (exists(".sparkRHivesc", envir = .sparkREnv)) { - return(get(".sparkRHivesc", envir = .sparkREnv)) + .Deprecated("sparkR.session.getOrCreate") + + if (exists(".sparkRsession", envir = .sparkREnv)) { + return(get(".sparkRsession", envir = .sparkREnv)) } # If jsc is NULL, create a Spark Context sc <- if (is.null(jsc)) { - sparkR.init() + sparkR.sparkContext() } else { jsc } @@ -289,10 +305,45 @@ sparkRHive.init <- function(jsc = NULL) { stop("Spark SQL is not built with Hive support") }) - assign(".sparkRHivesc", hiveCtx, envir = .sparkREnv) + assign(".sparkRsession", hiveCtx, envir = .sparkREnv) hiveCtx } +#' Get the existing SparkSession or initialize a new SparkSession. +#' +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session.getOrCreate() +#' df <- read.json(path) +#'} +#' @note since 2.0.0 + +sparkR.session.getOrCreate <- function( + master = "", + appName = "SparkR", + sparkHome = Sys.getenv("SPARK_HOME"), + sparkConfig = list(), + sparkExecutorEnv = list(), + sparkJars = "", + sparkPackages = "", + ...) { + + if (!exists(".sparkRjsc", envir = .sparkREnv)) { + sparkR.sparkContext(master, appName, sparkHome, sparkEnvir, sparkExecutorEnv, sparkJars, + sparkPackages) + } + + if (exists(".sparkRsession", envir = .sparkREnv)) { + sparkSession <- get(".sparkRsession", envir = .sparkREnv) + # TODO: apply config + } else { + sparkSession <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getOrCreateSparkSession") + assign(".sparkRsession", sparkSession, envir = .sparkREnv) + } + sparkSession +} + #' Assigns a group ID to all the jobs started by this thread until the group ID is set to a #' different value or cleared. 
#' diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index fe426fa3c7e8a..18260cfba7439 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -26,13 +26,22 @@ import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.api.r.SerDe import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, RelationalGroupedDataset, Row, SaveMode, SQLContext} +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema +import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.sql.types._ private[sql] object SQLUtils { SerDe.registerSqlSerDe((readSqlObject, writeSqlObject)) + def getOrCreateSparkSession(jsc: JavaSparkContext): SparkSession = { + if (SparkSession.hiveClassesArePresent) { + SparkSession.builder().sparkContext(HiveUtils.withHiveExternalCatalog(jsc.sc)).getOrCreate() + } else { + SparkSession.builder().sparkContext(jsc.sc).getOrCreate() + } + } + def createSQLContext(jsc: JavaSparkContext): SQLContext = { SQLContext.getOrCreate(jsc.sc) } From 8e031cad0f72aaf6a6e8e61967f8466f476697af Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Thu, 9 Jun 2016 08:25:14 -0700 Subject: [PATCH 02/10] more changes for spark session --- R/pkg/NAMESPACE | 8 +- R/pkg/R/DataFrame.R | 8 +- R/pkg/R/SQLContext.R | 107 +++++++++--------- R/pkg/R/backend.R | 2 +- R/pkg/R/sparkR.R | 86 +++++++------- R/pkg/inst/profile/shell.R | 8 +- R/pkg/inst/tests/testthat/jarTest.R | 5 +- R/pkg/inst/tests/testthat/packageInAJarTest.R | 4 +- R/pkg/inst/tests/testthat/test_Serde.R | 4 +- R/pkg/inst/tests/testthat/test_binaryFile.R | 5 +- .../tests/testthat/test_binary_function.R | 5 +- R/pkg/inst/tests/testthat/test_broadcast.R | 5 +- R/pkg/inst/tests/testthat/test_context.R | 43 +++---- .../inst/tests/testthat/test_includePackage.R | 5 +- R/pkg/inst/tests/testthat/test_mllib.R | 7 +- .../tests/testthat/test_parallelize_collect.R | 5 +- R/pkg/inst/tests/testthat/test_rdd.R | 5 +- R/pkg/inst/tests/testthat/test_shuffle.R | 5 +- R/pkg/inst/tests/testthat/test_sparkSQL.R | 23 +++- R/pkg/inst/tests/testthat/test_take.R | 19 ++-- R/pkg/inst/tests/testthat/test_textFile.R | 5 +- R/pkg/inst/tests/testthat/test_utils.R | 5 +- .../org/apache/spark/sql/api/r/SQLUtils.scala | 53 ++++++--- 23 files changed, 246 insertions(+), 176 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 9412ec3f9e09b..835181c2fe393 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -6,10 +6,15 @@ importFrom(methods, setGeneric, setMethod, setOldClass) #useDynLib(SparkR, stringHashCode) # S3 methods exported +export("sparkR.session.getOrCreate") export("sparkR.init") export("sparkR.stop") +export("sparkR.session.stop") export("print.jobj") +export("sparkRSQL.init", + "sparkRHive.init") + # MLlib integration exportMethods("glm", "spark.glm", @@ -287,9 +292,6 @@ exportMethods("%in%", exportClasses("GroupedData") exportMethods("agg") -export("sparkRSQL.init", - "sparkRHive.init") - export("as.DataFrame", "cacheTable", "clearCache", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 4e044565f4954..ea091c81016d4 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2333,9 +2333,7 @@ setMethod("write.df", signature(df = "SparkDataFrame", path = "character"), function(df, path, source = NULL, mode = "error", ...){ if 
(is.null(source)) { - sqlContext <- getSqlContext() - source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", - "org.apache.spark.sql.parquet") + source <- getDefaultSqlSource() } jmode <- convertToJSaveMode(mode) options <- varargsToEnv(...) @@ -2393,9 +2391,7 @@ setMethod("saveAsTable", signature(df = "SparkDataFrame", tableName = "character"), function(df, tableName, source = NULL, mode="error", ...){ if (is.null(source)) { - sqlContext <- getSqlContext() - source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", - "org.apache.spark.sql.parquet") + source <- getDefaultSqlSource() } jmode <- convertToJSaveMode(mode) options <- varargsToEnv(...) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 914b02a47ad67..7d732b35b223f 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -53,7 +53,8 @@ dispatchFunc <- function(newFuncSig, x, ...) { # Strip sqlContext from list of parameters and then pass the rest along. contextNames <- c("org.apache.spark.sql.SQLContext", "org.apache.spark.sql.hive.HiveContext", - "org.apache.spark.sql.hive.test.TestHiveContext") + "org.apache.spark.sql.hive.test.TestHiveContext", + "org.apache.spark.sql.SparkSession") if (missing(x) && length(list(...)) == 0) { f() } else if (class(x) == "jobj" && @@ -65,14 +66,12 @@ dispatchFunc <- function(newFuncSig, x, ...) { } } -#' return the SQL Context -getSqlContext <- function() { - if (exists(".sparkRHivesc", envir = .sparkREnv)) { - get(".sparkRHivesc", envir = .sparkREnv) - } else if (exists(".sparkRSQLsc", envir = .sparkREnv)) { - get(".sparkRSQLsc", envir = .sparkREnv) +#' return the SparkSession +getSparkSession <- function() { + if (exists(".sparkRsession", envir = .sparkREnv)) { + get(".sparkRsession", envir = .sparkREnv) } else { - stop("SQL context not initialized") + stop("SparkSession not initialized") } } @@ -109,6 +108,13 @@ infer_type <- function(x) { } } +getDefaultSqlSource <- function() { + sparkSession <- getSparkSession() + conf <- callJMethod(sparkSession, "conf") + source <- callJMethod(conf, "get", "spark.sql.sources.default", "org.apache.spark.sql.parquet") + source +} + #' Create a SparkDataFrame #' #' Converts R data.frame or list into SparkDataFrame. 
@@ -131,7 +137,7 @@ infer_type <- function(x) { # TODO(davies): support sampling and infer type from NA createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) { - sqlContext <- getSqlContext() + sparkSession <- getSparkSession() if (is.data.frame(data)) { # get the names of columns, they will be put into RDD if (is.null(schema)) { @@ -158,7 +164,7 @@ createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) { data <- do.call(mapply, append(args, data)) } if (is.list(data)) { - sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sqlContext) + sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) rdd <- parallelize(sc, data) } else if (inherits(data, "RDD")) { rdd <- data @@ -201,7 +207,7 @@ createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) { jrdd <- getJRDD(lapply(rdd, function(x) x), "row") srdd <- callJMethod(jrdd, "rdd") sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "createDF", - srdd, schema$jobj, sqlContext) + srdd, schema$jobj, sparkSession) dataFrame(sdf) } @@ -265,10 +271,10 @@ setMethod("toDF", signature(x = "RDD"), #' @method read.json default read.json.default <- function(path) { - sqlContext <- getSqlContext() + sparkSession <- getSparkSession() # Allow the user to have a more flexible definiton of the text file path paths <- as.list(suppressWarnings(normalizePath(path))) - read <- callJMethod(sqlContext, "read") + read <- callJMethod(sparkSession, "read") sdf <- callJMethod(read, "json", paths) dataFrame(sdf) } @@ -336,10 +342,10 @@ jsonRDD <- function(sqlContext, rdd, schema = NULL, samplingRatio = 1.0) { #' @method read.parquet default read.parquet.default <- function(path) { - sqlContext <- getSqlContext() + sparkSession <- getSparkSession() # Allow the user to have a more flexible definiton of the text file path paths <- as.list(suppressWarnings(normalizePath(path))) - read <- callJMethod(sqlContext, "read") + read <- callJMethod(sparkSession, "read") sdf <- callJMethod(read, "parquet", paths) dataFrame(sdf) } @@ -385,10 +391,10 @@ parquetFile <- function(x, ...) { #' @method read.text default read.text.default <- function(path) { - sqlContext <- getSqlContext() + sparkSession <- getSparkSession() # Allow the user to have a more flexible definiton of the text file path paths <- as.list(suppressWarnings(normalizePath(path))) - read <- callJMethod(sqlContext, "read") + read <- callJMethod(sparkSession, "read") sdf <- callJMethod(read, "text", paths) dataFrame(sdf) } @@ -418,8 +424,8 @@ read.text <- function(x, ...) { #' @method sql default sql.default <- function(sqlQuery) { - sqlContext <- getSqlContext() - sdf <- callJMethod(sqlContext, "sql", sqlQuery) + sparkSession <- getSparkSession() + sdf <- callJMethod(sparkSession, "sql", sqlQuery) dataFrame(sdf) } @@ -449,8 +455,8 @@ sql <- function(x, ...) 
{ #' @note since 2.0.0 tableToDF <- function(tableName) { - sqlContext <- getSqlContext() - sdf <- callJMethod(sqlContext, "table", tableName) + sparkSession <- getSparkSession() + sdf <- callJMethod(sparkSession, "table", tableName) dataFrame(sdf) } @@ -472,12 +478,8 @@ tableToDF <- function(tableName) { #' @method tables default tables.default <- function(databaseName = NULL) { - sqlContext <- getSqlContext() - jdf <- if (is.null(databaseName)) { - callJMethod(sqlContext, "tables") - } else { - callJMethod(sqlContext, "tables", databaseName) - } + sparkSession <- getSparkSession() + jdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getTables", sparkSession, databaseName) dataFrame(jdf) } @@ -503,12 +505,11 @@ tables <- function(x, ...) { #' @method tableNames default tableNames.default <- function(databaseName = NULL) { - sqlContext <- getSqlContext() - if (is.null(databaseName)) { - callJMethod(sqlContext, "tableNames") - } else { - callJMethod(sqlContext, "tableNames", databaseName) - } + sparkSession <- getSparkSession() + jdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", + "getTableNames", + sparkSession, + databaseName) } tableNames <- function(x, ...) { @@ -536,8 +537,9 @@ tableNames <- function(x, ...) { #' @method cacheTable default cacheTable.default <- function(tableName) { - sqlContext <- getSqlContext() - callJMethod(sqlContext, "cacheTable", tableName) + sparkSession <- getSparkSession() + catalog <- callJMethod(sparkSession, "catalog") + callJMethod(catalog, "cacheTable", tableName) } cacheTable <- function(x, ...) { @@ -565,8 +567,9 @@ cacheTable <- function(x, ...) { #' @method uncacheTable default uncacheTable.default <- function(tableName) { - sqlContext <- getSqlContext() - callJMethod(sqlContext, "uncacheTable", tableName) + sparkSession <- getSparkSession() + catalog <- callJMethod(sparkSession, "catalog") + callJMethod(catalog, "uncacheTable", tableName) } uncacheTable <- function(x, ...) { @@ -587,8 +590,9 @@ uncacheTable <- function(x, ...) { #' @method clearCache default clearCache.default <- function() { - sqlContext <- getSqlContext() - callJMethod(sqlContext, "clearCache") + sparkSession <- getSparkSession() + catalog <- callJMethod(sparkSession, "catalog") + callJMethod(catalog, "clearCache") } clearCache <- function() { @@ -615,11 +619,12 @@ clearCache <- function() { #' @method dropTempTable default dropTempTable.default <- function(tableName) { - sqlContext <- getSqlContext() + sparkSession <- getSparkSession() if (class(tableName) != "character") { stop("tableName must be a string.") } - callJMethod(sqlContext, "dropTempTable", tableName) + catalog <- callJMethod(sparkSession, "catalog") + callJMethod(catalog, "dropTempView", tableName) } dropTempTable <- function(x, ...) { @@ -655,21 +660,20 @@ dropTempTable <- function(x, ...) { #' @method read.df default read.df.default <- function(path = NULL, source = NULL, schema = NULL, ...) { - sqlContext <- getSqlContext() + sparkSession <- getSparkSession() options <- varargsToEnv(...) 
if (!is.null(path)) { options[["path"]] <- path } if (is.null(source)) { - source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", - "org.apache.spark.sql.parquet") + source <- getDefaultSqlSource() } if (!is.null(schema)) { stopifnot(class(schema) == "structType") - sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sqlContext, source, + sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sparkSession, source, schema$jobj, options) } else { - sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sqlContext, source, options) + sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sparkSession, source, options) } dataFrame(sdf) } @@ -715,12 +719,13 @@ loadDF <- function(x, ...) { #' @method createExternalTable default createExternalTable.default <- function(tableName, path = NULL, source = NULL, ...) { - sqlContext <- getSqlContext() + sparkSession <- getSparkSession() options <- varargsToEnv(...) if (!is.null(path)) { options[["path"]] <- path } - sdf <- callJMethod(sqlContext, "createExternalTable", tableName, source, options) + catalog <- callJMethod(sparkSession, "catalog") + sdf <- callJMethod(catalog, "createExternalTable", tableName, source, options) dataFrame(sdf) } @@ -768,11 +773,11 @@ read.jdbc <- function(url, tableName, numPartitions = 0L, predicates = list(), ...) { jprops <- varargsToJProperties(...) - read <- callJMethod(sqlContext, "read") + read <- callJMethod(sparkSession, "read") if (!is.null(partitionColumn)) { if (is.null(numPartitions) || numPartitions == 0) { - sqlContext <- getSqlContext() - sc <- callJMethod(sqlContext, "sparkContext") + sparkSession <- getSparkSession() + sc <- callJMethod(sparkSession, "sparkContext") numPartitions <- callJMethod(sc, "defaultParallelism") } else { numPartitions <- numToInt(numPartitions) diff --git a/R/pkg/R/backend.R b/R/pkg/R/backend.R index 6c81492f8b675..860511d48d0d3 100644 --- a/R/pkg/R/backend.R +++ b/R/pkg/R/backend.R @@ -68,7 +68,7 @@ isRemoveMethod <- function(isStatic, objId, methodName) { # methodName - name of method to be invoked invokeJava <- function(isStatic, objId, methodName, ...) { if (!exists(".sparkRCon", .sparkREnv)) { - stop("No connection to backend found. Please re-run sparkR.init") + stop("No connection to backend found. Please re-run sparkR.session.getOrCreate()") } # If this isn't a removeJObject call diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index e8b03923af395..7766fa0b8a443 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -31,7 +31,18 @@ connExists <- function(env) { #' Stop the Spark context. #' #' Also terminates the backend this R session is connected to +#' @export sparkR.stop <- function() { + .Deprecated("sparkR.session.stop") + sparkR.session.stop() +} + +#' Stop the Spark Session and Spark Context. +#' +#' Also terminates the backend this R session is connected to. +#' @export +#' @note since 2.0.0 +sparkR.session.stop <- function() { env <- .sparkREnv if (exists(".sparkRCon", envir = env)) { if (exists(".sparkRjsc", envir = env)) { @@ -39,12 +50,8 @@ sparkR.stop <- function() { callJMethod(sc, "stop") rm(".sparkRjsc", envir = env) - if (exists(".sparkRSQLsc", envir = env)) { - rm(".sparkRSQLsc", envir = env) - } - - if (exists(".sparkRHivesc", envir = env)) { - rm(".sparkRHivesc", envir = env) + if (exists(".sparkRsession", envir = env)) { + rm(".sparkRsession", envir = env) } } @@ -119,6 +126,7 @@ sparkR.init <- function( sparkPackages) } +# Internal function to handle creating the SparkContext. 
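+# Not exported; it returns the JavaSparkContext jobj (as used by the tests via
+# sc <- sparkR.sparkContext()). Callers that need a SparkSession should use
+# sparkR.session.getOrCreate() instead.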
sparkR.sparkContext <- function( master = "", appName = "SparkR", @@ -130,7 +138,7 @@ sparkR.sparkContext <- function( if (exists(".sparkRjsc", envir = .sparkREnv)) { cat(paste("Re-using existing Spark Context.", - "Please stop SparkR with sparkR.stop() or restart R to create a new Spark Context\n")) + "Please stop SparkR with sparkR.session.stop() or restart R to create a new Spark Context\n")) return(get(".sparkRjsc", envir = .sparkREnv)) } @@ -243,6 +251,9 @@ sparkR.sparkContext <- function( #' This function creates a SparkContext from an existing JavaSparkContext and #' then uses it to initialize a new SQLContext #' +#' Starting SparkR 2.0, a SparkSession is initialized and returned instead. +#' This API is deprecated and kept for backward compatibility only. +#' #' @param jsc The existing JavaSparkContext created with SparkR.init() #' @export #' @examples @@ -258,17 +269,8 @@ sparkRSQL.init <- function(jsc = NULL) { return(get(".sparkRsession", envir = .sparkREnv)) } - # If jsc is NULL, create a Spark Context - sc <- if (is.null(jsc)) { - sparkR.sparkContext() - } else { - jsc - } - sqlContext <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", - "createSQLContext", - sc) - assign(".sparkRsession", sqlContext, envir = .sparkREnv) - sqlContext + # Default to without Hive support for backward compatibility. + sparkR.session.getOrCreate(enableHiveSupport = FALSE) } #' Initialize a new HiveContext. @@ -290,32 +292,34 @@ sparkRHive.init <- function(jsc = NULL) { return(get(".sparkRsession", envir = .sparkREnv)) } - # If jsc is NULL, create a Spark Context - sc <- if (is.null(jsc)) { - sparkR.sparkContext() - } else { - jsc - } - - ssc <- callJMethod(sc, "sc") - hiveCtx <- tryCatch({ - newJObject("org.apache.spark.sql.hive.HiveContext", ssc) - }, - error = function(err) { - stop("Spark SQL is not built with Hive support") - }) - - assign(".sparkRsession", hiveCtx, envir = .sparkREnv) - hiveCtx + # Default to without Hive support for backward compatibility. + sparkR.session.getOrCreate(enableHiveSupport = TRUE) } #' Get the existing SparkSession or initialize a new SparkSession. #' +#' @param master The Spark master URL +#' @param appName Application name to register with cluster manager +#' @param sparkHome Spark Home directory +#' @param sparkConfig Named list of Spark configuration to set on worker nodes +#' @param sparkExecutorConfig Named list of Spark configuration to be used when launching executors +#' @param sparkJars Character vector of jar files to pass to the worker nodes +#' @param sparkPackages Character vector of packages from spark-packages.org +#' @param enableHiveSupport Enable support for Hive #' @export #' @examples #'\dontrun{ #' sparkR.session.getOrCreate() #' df <- read.json(path) +#' +#' sparkR.session.getOrCreate("local[2]", "SparkR", "/home/spark") +#' sparkR.session.getOrCreate("local[2]", "SparkR", "/home/spark", +#' list(spark.executor.memory="1g")) +#' sparkR.session.getOrCreate("yarn-client", "SparkR", "/home/spark", +#' list(spark.executor.memory="4g"), +#' list(LD_LIBRARY_PATH="/directory of JVM libraries (libjvm.so) on workers/"), +#' c("one.jar", "two.jar", "three.jar"), +#' c("com.databricks:spark-avro_2.10:2.0.1")) #'} #' @note since 2.0.0 @@ -324,21 +328,27 @@ sparkR.session.getOrCreate <- function( appName = "SparkR", sparkHome = Sys.getenv("SPARK_HOME"), sparkConfig = list(), - sparkExecutorEnv = list(), + sparkExecutorConfig = list(), sparkJars = "", sparkPackages = "", + enableHiveSupport = TRUE, ...) 
{ if (!exists(".sparkRjsc", envir = .sparkREnv)) { - sparkR.sparkContext(master, appName, sparkHome, sparkEnvir, sparkExecutorEnv, sparkJars, + sparkR.sparkContext(master, appName, sparkHome, sparkConfig, sparkExecutorConfig, sparkJars, sparkPackages) + stopifnot(exists(".sparkRjsc", envir = .sparkREnv)) } if (exists(".sparkRsession", envir = .sparkREnv)) { sparkSession <- get(".sparkRsession", envir = .sparkREnv) # TODO: apply config } else { - sparkSession <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getOrCreateSparkSession") + jsc <- get(".sparkRjsc", envir = .sparkREnv) + sparkSession <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", + "getOrCreateSparkSession", + jsc, + enableHiveSupport) assign(".sparkRsession", sparkSession, envir = .sparkREnv) } sparkSession diff --git a/R/pkg/inst/profile/shell.R b/R/pkg/inst/profile/shell.R index 90a3761e41f82..d1bc6c43b2d96 100644 --- a/R/pkg/inst/profile/shell.R +++ b/R/pkg/inst/profile/shell.R @@ -24,11 +24,11 @@ old <- getOption("defaultPackages") options(defaultPackages = c(old, "SparkR")) - sc <- SparkR::sparkR.init() + spark <- SparkR::sparkR.session.getOrCreate() + assign("spark", spark, envir=.GlobalEnv) + sc <- SparkR:::callJMethod(spark, "sparkContext") assign("sc", sc, envir=.GlobalEnv) - sqlContext <- SparkR::sparkRSQL.init(sc) sparkVer <- SparkR:::callJMethod(sc, "version") - assign("sqlContext", sqlContext, envir=.GlobalEnv) cat("\n Welcome to") cat("\n") cat(" ____ __", "\n") @@ -43,5 +43,5 @@ cat(" /_/", "\n") cat("\n") - cat("\n Spark context is available as sc, SQL context is available as sqlContext\n") + cat("\n SparkSession available as 'spark'.\n") } diff --git a/R/pkg/inst/tests/testthat/jarTest.R b/R/pkg/inst/tests/testthat/jarTest.R index d68bb20950b00..a56cf81c82bfc 100644 --- a/R/pkg/inst/tests/testthat/jarTest.R +++ b/R/pkg/inst/tests/testthat/jarTest.R @@ -16,7 +16,8 @@ # library(SparkR) -sc <- sparkR.init() +sparkSession <- sparkR.session.getOrCreate() +sc <- SparkR:::callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) helloTest <- SparkR:::callJStatic("sparkR.test.hello", "helloWorld", @@ -27,6 +28,6 @@ basicFunction <- SparkR:::callJStatic("sparkR.test.basicFunction", 2L, 2L) -sparkR.stop() +sparkR.session.stop() output <- c(helloTest, basicFunction) writeLines(output) diff --git a/R/pkg/inst/tests/testthat/packageInAJarTest.R b/R/pkg/inst/tests/testthat/packageInAJarTest.R index c26b28b78dee8..746f60b857e45 100644 --- a/R/pkg/inst/tests/testthat/packageInAJarTest.R +++ b/R/pkg/inst/tests/testthat/packageInAJarTest.R @@ -17,13 +17,13 @@ library(SparkR) library(sparkPackageTest) -sc <- sparkR.init() +sparkSession <- sparkR.session.getOrCreate() run1 <- myfunc(5L) run2 <- myfunc(-4L) -sparkR.stop() +sparkR.session.stop() if (run1 != 6) quit(save = "no", status = 1) diff --git a/R/pkg/inst/tests/testthat/test_Serde.R b/R/pkg/inst/tests/testthat/test_Serde.R index dddce54d70443..ce7a593d7f62a 100644 --- a/R/pkg/inst/tests/testthat/test_Serde.R +++ b/R/pkg/inst/tests/testthat/test_Serde.R @@ -17,7 +17,7 @@ context("SerDe functionality") -sc <- sparkR.init() +sparkSession <- sparkR.session.getOrCreate() test_that("SerDe of primitive types", { x <- callJStatic("SparkRHandler", "echo", 1L) @@ -75,3 +75,5 @@ test_that("SerDe of list of lists", { y <- callJStatic("SparkRHandler", "echo", x) expect_equal(x, y) }) + +#sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_binaryFile.R b/R/pkg/inst/tests/testthat/test_binaryFile.R index 976a7558a816d..03ee52bea03cf 
100644 --- a/R/pkg/inst/tests/testthat/test_binaryFile.R +++ b/R/pkg/inst/tests/testthat/test_binaryFile.R @@ -18,7 +18,8 @@ context("functions on binary files") # JavaSparkContext handle -sc <- sparkR.init() +sparkSession <- sparkR.session.getOrCreate() +sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) mockFile <- c("Spark is pretty.", "Spark is awesome.") @@ -87,3 +88,5 @@ test_that("saveAsObjectFile()/objectFile() works with multiple paths", { unlink(fileName1, recursive = TRUE) unlink(fileName2, recursive = TRUE) }) + +#sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_binary_function.R b/R/pkg/inst/tests/testthat/test_binary_function.R index 7bad4d2a7e106..acd42e4ea7322 100644 --- a/R/pkg/inst/tests/testthat/test_binary_function.R +++ b/R/pkg/inst/tests/testthat/test_binary_function.R @@ -18,7 +18,8 @@ context("binary functions") # JavaSparkContext handle -sc <- sparkR.init() +sparkSession <- sparkR.session.getOrCreate() +sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Data nums <- 1:10 @@ -99,3 +100,5 @@ test_that("zipPartitions() on RDDs", { unlink(fileName) }) + +#sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_broadcast.R b/R/pkg/inst/tests/testthat/test_broadcast.R index 8be6efc3dbed3..c0fe9d1315139 100644 --- a/R/pkg/inst/tests/testthat/test_broadcast.R +++ b/R/pkg/inst/tests/testthat/test_broadcast.R @@ -18,7 +18,8 @@ context("broadcast variables") # JavaSparkContext handle -sc <- sparkR.init() +sparkSession <- sparkR.session.getOrCreate() +sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Partitioned data nums <- 1:2 @@ -46,3 +47,5 @@ test_that("without using broadcast variable", { expected <- list(sum(randomMat) * 1, sum(randomMat) * 2) expect_equal(actual, expected) }) + +#sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_context.R b/R/pkg/inst/tests/testthat/test_context.R index 126484c995fb3..a143e14f2e3c2 100644 --- a/R/pkg/inst/tests/testthat/test_context.R +++ b/R/pkg/inst/tests/testthat/test_context.R @@ -54,33 +54,23 @@ test_that("Check masked functions", { sort(namesOfMaskedCompletely, na.last = TRUE)) }) -test_that("repeatedly starting and stopping SparkR", { - for (i in 1:4) { - sc <- sparkR.init() - rdd <- parallelize(sc, 1:20, 2L) - expect_equal(count(rdd), 20) - sparkR.stop() - } -}) - -test_that("repeatedly starting and stopping SparkR SQL", { - for (i in 1:4) { - sc <- sparkR.init() - sqlContext <- sparkRSQL.init(sc) - df <- createDataFrame(data.frame(a = 1:20)) - expect_equal(count(df), 20) - sparkR.stop() - } -}) +# test_that("repeatedly starting and stopping SparkR", { +# for (i in 1:4) { +# sparkR.session.getOrCreate() +# df <- createDataFrame(data.frame(dummy=1:i)) +# expect_equal(count(df), i) +# sparkR.session.stop() +# Sys.sleep(5) # Need more time to shutdown Hive metastore +# } +# }) test_that("rdd GC across sparkR.stop", { - sparkR.stop() - sc <- sparkR.init() # sc should get id 0 + sc <- sparkR.sparkContext() # sc should get id 0 rdd1 <- parallelize(sc, 1:20, 2L) # rdd1 should get id 1 rdd2 <- parallelize(sc, 1:10, 2L) # rdd2 should get id 2 - sparkR.stop() + sparkR.session.stop() - sc <- sparkR.init() # sc should get id 0 again + sc <- sparkR.sparkContext() # sc should get id 0 again # GC rdd1 before creating rdd3 and rdd2 after rm(rdd1) @@ -97,15 +87,17 @@ test_that("rdd GC across sparkR.stop", { }) test_that("job group functions can be called", { - sc <- 
sparkR.init() + sc <- sparkR.sparkContext() setJobGroup(sc, "groupId", "job description", TRUE) cancelJobGroup(sc, "groupId") clearJobGroup(sc) + sparkR.session.stop() }) test_that("utility function can be called", { - sc <- sparkR.init() + sc <- sparkR.sparkContext() setLogLevel(sc, "ERROR") + sparkR.session.stop() }) test_that("getClientModeSparkSubmitOpts() returns spark-submit args from whitelist", { @@ -156,7 +148,8 @@ test_that("sparkJars sparkPackages as comma-separated strings", { }) test_that("spark.lapply should perform simple transforms", { - sc <- sparkR.init() + sc <- sparkR.sparkContext() doubled <- spark.lapply(sc, 1:10, function(x) { 2 * x }) expect_equal(doubled, as.list(2 * 1:10)) + sparkR.session.stop() }) diff --git a/R/pkg/inst/tests/testthat/test_includePackage.R b/R/pkg/inst/tests/testthat/test_includePackage.R index 8152b448d0870..697a9de44c6e8 100644 --- a/R/pkg/inst/tests/testthat/test_includePackage.R +++ b/R/pkg/inst/tests/testthat/test_includePackage.R @@ -18,7 +18,8 @@ context("include R packages") # JavaSparkContext handle -sc <- sparkR.init() +sparkSession <- sparkR.session.getOrCreate() +sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Partitioned data nums <- 1:2 @@ -55,3 +56,5 @@ test_that("use include package", { actual <- collect(data) } }) + +#sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index 59ef15c1e9fd5..9efd26e28b3f0 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -20,10 +20,7 @@ library(testthat) context("MLlib functions") # Tests for MLlib functions in SparkR - -sc <- sparkR.init() - -sqlContext <- sparkRSQL.init(sc) +sparkSession <- sparkR.session.getOrCreate() test_that("formula of spark.glm", { training <- suppressWarnings(createDataFrame(iris)) @@ -456,3 +453,5 @@ test_that("spark.survreg", { expect_equal(predict(model, rData)[[1]], 3.724591, tolerance = 1e-4) } }) + +#sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_parallelize_collect.R b/R/pkg/inst/tests/testthat/test_parallelize_collect.R index 2552127cc547f..12667d44a853f 100644 --- a/R/pkg/inst/tests/testthat/test_parallelize_collect.R +++ b/R/pkg/inst/tests/testthat/test_parallelize_collect.R @@ -33,7 +33,8 @@ numPairs <- list(list(1, 1), list(1, 2), list(2, 2), list(2, 3)) strPairs <- list(list(strList, strList), list(strList, strList)) # JavaSparkContext handle -jsc <- sparkR.init() +sparkSession <- sparkR.session.getOrCreate() +jsc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Tests @@ -107,3 +108,5 @@ test_that("parallelize() and collect() work for lists of pairs (pairwise data)", expect_equal(collect(strPairsRDDD1), strPairs) expect_equal(collect(strPairsRDDD2), strPairs) }) + +#sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_rdd.R b/R/pkg/inst/tests/testthat/test_rdd.R index b6c8e1dc6c1b7..c53fa92cf8766 100644 --- a/R/pkg/inst/tests/testthat/test_rdd.R +++ b/R/pkg/inst/tests/testthat/test_rdd.R @@ -18,7 +18,8 @@ context("basic RDD functions") # JavaSparkContext handle -sc <- sparkR.init() +sparkSession <- sparkR.session.getOrCreate() +sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Data nums <- 1:10 @@ -799,3 +800,5 @@ test_that("Test correct concurrency of RRDD.compute()", { count <- callJMethod(zrdd, "count") expect_equal(count, 1000) }) + +#sparkR.session.stop() diff --git 
a/R/pkg/inst/tests/testthat/test_shuffle.R b/R/pkg/inst/tests/testthat/test_shuffle.R index d3d0f8a24d01c..590f45bf4c895 100644 --- a/R/pkg/inst/tests/testthat/test_shuffle.R +++ b/R/pkg/inst/tests/testthat/test_shuffle.R @@ -18,7 +18,8 @@ context("partitionBy, groupByKey, reduceByKey etc.") # JavaSparkContext handle -sc <- sparkR.init() +sparkSession <- sparkR.session.getOrCreate() +sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Data intPairs <- list(list(1L, -1), list(2L, 100), list(2L, 1), list(1L, 200)) @@ -219,3 +220,5 @@ test_that("test partitionBy with string keys", { expect_equal(sortKeyValueList(actual_first), sortKeyValueList(expected_first)) expect_equal(sortKeyValueList(actual_second), sortKeyValueList(expected_second)) }) + +#sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 607bd9c12fa05..62baa924fed7f 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -40,19 +40,18 @@ setHiveContext <- function(sc) { error = function(err) { skip("Hive is not build with SparkSQL, skipped") }) - assign(".sparkRHivesc", hiveCtx, envir = .sparkREnv) + assign(".sparkRsession", hiveCtx, envir = .sparkREnv) hiveCtx } unsetHiveContext <- function() { - remove(".sparkRHivesc", envir = .sparkREnv) + remove(".sparkRsession", envir = .sparkREnv) } # Tests for SparkSQL functions in SparkR -sc <- sparkR.init() - -sqlContext <- sparkRSQL.init(sc) +sparkSession <- sparkR.session.getOrCreate() +sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) mockLines <- c("{\"name\":\"Michael\"}", "{\"name\":\"Andy\", \"age\":30}", @@ -79,7 +78,16 @@ complexTypeJsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(mockLinesComplexType, complexTypeJsonPath) test_that("calling sparkRSQL.init returns existing SQL context", { - expect_equal(sparkRSQL.init(sc), sqlContext) + sqlContext <- suppressWarnings(sparkRSQL.init(sc)) + expect_equal(suppressWarnings(sparkRSQL.init(sc)), sqlContext) +}) + +test_that("calling sparkRSQL.init returns existing SparkSession", { + expect_equal(suppressWarnings(sparkRSQL.init(sc)), sparkSession) +}) + +test_that("calling sparkR.session.getOrCreate returns existing SparkSession", { + expect_equal(sparkR.session.getOrCreate(), sparkSession) }) test_that("infer types and check types", { @@ -431,6 +439,7 @@ test_that("read/write json files", { }) test_that("jsonRDD() on a RDD with json string", { + sqlContext <- suppressWarnings(sparkRSQL.init(sc)) rdd <- parallelize(sc, mockLines) expect_equal(count(rdd), 3) df <- suppressWarnings(jsonRDD(sqlContext, rdd)) @@ -2257,6 +2266,7 @@ test_that("Window functions on a DataFrame", { }) test_that("createDataFrame sqlContext parameter backward compatibility", { + sqlContext <- suppressWarnings(sparkRSQL.init(sc)) a <- 1:3 b <- c("a", "b", "c") ldf <- data.frame(a, b) @@ -2298,6 +2308,7 @@ test_that("randomSplit", { expect_true(all(sapply(abs(counts / num - weights / sum(weights)), function(e) { e < 0.05 }))) }) + unlink(parquetPath) unlink(jsonPath) unlink(jsonPathNa) diff --git a/R/pkg/inst/tests/testthat/test_take.R b/R/pkg/inst/tests/testthat/test_take.R index c2c724cdc762f..b88e1fdf011ea 100644 --- a/R/pkg/inst/tests/testthat/test_take.R +++ b/R/pkg/inst/tests/testthat/test_take.R @@ -30,10 +30,11 @@ strList <- list("Dexter Morgan: Blood. Sometimes it sets my teeth on edge, ", "raising me. 
But they're both dead now. I didn't kill them. Honest.") # JavaSparkContext handle -jsc <- sparkR.init() +sparkSession <- sparkR.session.getOrCreate() +sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) test_that("take() gives back the original elements in correct count and order", { - numVectorRDD <- parallelize(jsc, numVector, 10) + numVectorRDD <- parallelize(sc, numVector, 10) # case: number of elements to take is less than the size of the first partition expect_equal(take(numVectorRDD, 1), as.list(head(numVector, n = 1))) # case: number of elements to take is the same as the size of the first partition @@ -42,20 +43,20 @@ test_that("take() gives back the original elements in correct count and order", expect_equal(take(numVectorRDD, length(numVector)), as.list(numVector)) expect_equal(take(numVectorRDD, length(numVector) + 1), as.list(numVector)) - numListRDD <- parallelize(jsc, numList, 1) - numListRDD2 <- parallelize(jsc, numList, 4) + numListRDD <- parallelize(sc, numList, 1) + numListRDD2 <- parallelize(sc, numList, 4) expect_equal(take(numListRDD, 3), take(numListRDD2, 3)) expect_equal(take(numListRDD, 5), take(numListRDD2, 5)) expect_equal(take(numListRDD, 1), as.list(head(numList, n = 1))) expect_equal(take(numListRDD2, 999), numList) - strVectorRDD <- parallelize(jsc, strVector, 2) - strVectorRDD2 <- parallelize(jsc, strVector, 3) + strVectorRDD <- parallelize(sc, strVector, 2) + strVectorRDD2 <- parallelize(sc, strVector, 3) expect_equal(take(strVectorRDD, 4), as.list(strVector)) expect_equal(take(strVectorRDD2, 2), as.list(head(strVector, n = 2))) - strListRDD <- parallelize(jsc, strList, 4) - strListRDD2 <- parallelize(jsc, strList, 1) + strListRDD <- parallelize(sc, strList, 4) + strListRDD2 <- parallelize(sc, strList, 1) expect_equal(take(strListRDD, 3), as.list(head(strList, n = 3))) expect_equal(take(strListRDD2, 1), as.list(head(strList, n = 1))) @@ -64,3 +65,5 @@ test_that("take() gives back the original elements in correct count and order", expect_equal(length(take(numListRDD, 0)), 0) expect_equal(length(take(numVectorRDD, 0)), 0) }) + +#sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_textFile.R b/R/pkg/inst/tests/testthat/test_textFile.R index e64ef1bb31a3a..84af0b4d04ef3 100644 --- a/R/pkg/inst/tests/testthat/test_textFile.R +++ b/R/pkg/inst/tests/testthat/test_textFile.R @@ -18,7 +18,8 @@ context("the textFile() function") # JavaSparkContext handle -sc <- sparkR.init() +sparkSession <- sparkR.session.getOrCreate() +sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) mockFile <- c("Spark is pretty.", "Spark is awesome.") @@ -159,3 +160,5 @@ test_that("Pipelined operations on RDDs created using textFile", { unlink(fileName) }) + +#sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R index 54d2eca50eaf5..4cbe5f72a2e32 100644 --- a/R/pkg/inst/tests/testthat/test_utils.R +++ b/R/pkg/inst/tests/testthat/test_utils.R @@ -18,7 +18,8 @@ context("functions in utils.R") # JavaSparkContext handle -sc <- sparkR.init() +sparkSession <- sparkR.session.getOrCreate() +sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) test_that("convertJListToRList() gives back (deserializes) the original JLists of strings and integers", { @@ -168,3 +169,5 @@ test_that("convertToJSaveMode", { test_that("hashCode", { expect_error(hashCode("bc53d3605e8a5b7de1e8e271c2317645"), NA) }) + 
+#sparkR.session.stop() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index 18260cfba7439..cb6c60ed1d745 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -22,32 +22,35 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, Da import scala.collection.JavaConverters._ import scala.util.matching.Regex +import org.apache.spark.SparkContext import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.api.r.SerDe import org.apache.spark.broadcast.Broadcast +import org.apache.spark.internal.config.CATALOG_IMPLEMENTATION import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema -import org.apache.spark.sql.hive.HiveUtils +import org.apache.spark.sql.execution.command.ShowTablesCommand import org.apache.spark.sql.types._ private[sql] object SQLUtils { SerDe.registerSqlSerDe((readSqlObject, writeSqlObject)) - def getOrCreateSparkSession(jsc: JavaSparkContext): SparkSession = { - if (SparkSession.hiveClassesArePresent) { - SparkSession.builder().sparkContext(HiveUtils.withHiveExternalCatalog(jsc.sc)).getOrCreate() + def withHiveExternalCatalog(sc: SparkContext): SparkContext = { + sc.conf.set(CATALOG_IMPLEMENTATION.key, "hive") + sc + } + + def getOrCreateSparkSession(jsc: JavaSparkContext, enableHiveSupport: Boolean): SparkSession = { + if (SparkSession.hiveClassesArePresent && enableHiveSupport) { + SparkSession.builder().sparkContext(withHiveExternalCatalog(jsc.sc)).getOrCreate() } else { SparkSession.builder().sparkContext(jsc.sc).getOrCreate() } } - def createSQLContext(jsc: JavaSparkContext): SQLContext = { - SQLContext.getOrCreate(jsc.sc) - } - - def getJavaSparkContext(sqlCtx: SQLContext): JavaSparkContext = { - new JavaSparkContext(sqlCtx.sparkContext) + def getJavaSparkContext(spark: SparkSession): JavaSparkContext = { + new JavaSparkContext(spark.sparkContext) } def createStructType(fields : Seq[StructField]): StructType = { @@ -104,10 +107,10 @@ private[sql] object SQLUtils { StructField(name, dtObj, nullable) } - def createDF(rdd: RDD[Array[Byte]], schema: StructType, sqlContext: SQLContext): DataFrame = { + def createDF(rdd: RDD[Array[Byte]], schema: StructType, sparkSession: SparkSession): DataFrame = { val num = schema.fields.length val rowRDD = rdd.map(bytesToRow(_, schema)) - sqlContext.createDataFrame(rowRDD, schema) + sparkSession.createDataFrame(rowRDD, schema) } def dfToRowRDD(df: DataFrame): JavaRDD[Array[Byte]] = { @@ -200,18 +203,18 @@ private[sql] object SQLUtils { } def loadDF( - sqlContext: SQLContext, + sparkSession: SparkSession, source: String, options: java.util.Map[String, String]): DataFrame = { - sqlContext.read.format(source).options(options).load() + sparkSession.read.format(source).options(options).load() } def loadDF( - sqlContext: SQLContext, + sparkSession: SparkSession, source: String, schema: StructType, options: java.util.Map[String, String]): DataFrame = { - sqlContext.read.format(source).schema(schema).options(options).load() + sparkSession.read.format(source).schema(schema).options(options).load() } def readSqlObject(dis: DataInputStream, dataType: Char): Object = { @@ -236,4 +239,22 @@ private[sql] object SQLUtils { false } } + + def getTables(sparkSession: SparkSession, databaseName: String): DataFrame = { + databaseName match { 
+ case n: String if n != null && n.trim.nonEmpty => + Dataset.ofRows(sparkSession, ShowTablesCommand(Some(n), None)) + case _ => + Dataset.ofRows(sparkSession, ShowTablesCommand(None, None)) + } + } + + def getTableNames(sparkSession: SparkSession, databaseName: String): Array[String] = { + databaseName match { + case n: String if n != null && n.trim.nonEmpty => + sparkSession.catalog.listTables(n).collect().map(_.name) + case _ => + sparkSession.catalog.listTables().collect().map(_.name) + } + } } From 35688e5ca091df79b903c7f78d6fa0ffe1f55e32 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Sat, 11 Jun 2016 04:16:10 -0700 Subject: [PATCH 03/10] fix tests --- R/pkg/R/SQLContext.R | 8 +++--- R/pkg/inst/tests/testthat/test_sparkSQL.R | 32 ++++++++++++++--------- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 7d732b35b223f..a2e7b0c804016 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -506,10 +506,10 @@ tables <- function(x, ...) { tableNames.default <- function(databaseName = NULL) { sparkSession <- getSparkSession() - jdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", - "getTableNames", - sparkSession, - databaseName) + callJStatic("org.apache.spark.sql.api.r.SQLUtils", + "getTableNames", + sparkSession, + databaseName) } tableNames <- function(x, ...) { diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 62baa924fed7f..559a911f2ac39 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -33,19 +33,29 @@ markUtf8 <- function(s) { } setHiveContext <- function(sc) { - ssc <- callJMethod(sc, "sc") - hiveCtx <- tryCatch({ - newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) - }, - error = function(err) { - skip("Hive is not build with SparkSQL, skipped") - }) - assign(".sparkRsession", hiveCtx, envir = .sparkREnv) - hiveCtx + if (exists(".testHiveSession", envir = .sparkREnv)) { + hiveSession <- get(".testHiveSession", envir = .sparkREnv) + } else { + # initialize once and reuse + ssc <- callJMethod(sc, "sc") + hiveCtx <- tryCatch({ + newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) + }, + error = function(err) { + skip("Hive is not build with SparkSQL, skipped") + }) + hiveSession <- callJMethod(hiveCtx, "sparkSession") + } + previousSession <- get(".sparkRsession", envir = .sparkREnv) + assign(".sparkRsession", hiveSession, envir = .sparkREnv) + assign(".prevSparkRsession", previousSession, envir = .sparkREnv) + hiveSession } unsetHiveContext <- function() { - remove(".sparkRsession", envir = .sparkREnv) + previousSession <- get(".prevSparkRsession", envir = .sparkREnv) + assign(".sparkRsession", previousSession, envir = .sparkREnv) + remove(".prevSparkRsession", envir = .sparkREnv) } # Tests for SparkSQL functions in SparkR @@ -2237,7 +2247,6 @@ test_that("gapply() on a DataFrame", { }) test_that("Window functions on a DataFrame", { - setHiveContext(sc) df <- createDataFrame(list(list(1L, "1"), list(2L, "2"), list(1L, "1"), list(2L, "2")), schema = c("key", "value")) ws <- orderBy(window.partitionBy("key"), "value") @@ -2262,7 +2271,6 @@ test_that("Window functions on a DataFrame", { result <- collect(select(df, over(lead("key", 1), ws), over(lead("value", 1), ws))) names(result) <- c("key", "value") expect_equal(result, expected) - unsetHiveContext() }) test_that("createDataFrame sqlContext parameter backward compatibility", { From 
fde449118f902b8f5672123373d26a561d9d78a7 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Sun, 12 Jun 2016 18:29:45 -0700 Subject: [PATCH 04/10] add support for updating config for existing session, fix read.jdbc --- R/pkg/R/SQLContext.R | 3 +- R/pkg/R/sparkR.R | 48 ++++++++++++++----- R/pkg/R/utils.R | 9 ++++ R/pkg/inst/tests/testthat/test_Serde.R | 2 +- R/pkg/inst/tests/testthat/test_binaryFile.R | 2 +- .../tests/testthat/test_binary_function.R | 2 +- R/pkg/inst/tests/testthat/test_broadcast.R | 2 +- .../inst/tests/testthat/test_includePackage.R | 2 +- R/pkg/inst/tests/testthat/test_mllib.R | 2 +- .../tests/testthat/test_parallelize_collect.R | 2 +- R/pkg/inst/tests/testthat/test_rdd.R | 2 +- R/pkg/inst/tests/testthat/test_shuffle.R | 2 +- R/pkg/inst/tests/testthat/test_take.R | 2 +- R/pkg/inst/tests/testthat/test_textFile.R | 2 +- R/pkg/inst/tests/testthat/test_utils.R | 15 +++++- .../org/apache/spark/sql/api/r/SQLUtils.scala | 21 ++++++-- 16 files changed, 89 insertions(+), 29 deletions(-) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index a2e7b0c804016..2053aaf13b494 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -772,11 +772,10 @@ read.jdbc <- function(url, tableName, partitionColumn = NULL, lowerBound = NULL, upperBound = NULL, numPartitions = 0L, predicates = list(), ...) { jprops <- varargsToJProperties(...) - + sparkSession <- getSparkSession() read <- callJMethod(sparkSession, "read") if (!is.null(partitionColumn)) { if (is.null(numPartitions) || numPartitions == 0) { - sparkSession <- getSparkSession() sc <- callJMethod(sparkSession, "sparkContext") numPartitions <- callJMethod(sc, "defaultParallelism") } else { diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 7766fa0b8a443..65cb48c84a681 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -122,7 +122,12 @@ sparkR.init <- function( sparkJars = "", sparkPackages = "") { .Deprecated("sparkR.session.getOrCreate") - sparkR.sparkContext(master, appName, sparkHome, sparkEnvir, sparkExecutorEnv, sparkJars, + sparkR.sparkContext(master, + appName, + sparkHome, + convertNamedListToEnv(sparkEnvir), + convertNamedListToEnv(sparkExecutorEnv), + sparkJars, sparkPackages) } @@ -131,22 +136,20 @@ sparkR.sparkContext <- function( master = "", appName = "SparkR", sparkHome = Sys.getenv("SPARK_HOME"), - sparkEnvir = list(), - sparkExecutorEnv = list(), + sparkEnvirMap = new.env(), + sparkExecutorEnvMap = new.env(), sparkJars = "", sparkPackages = "") { if (exists(".sparkRjsc", envir = .sparkREnv)) { cat(paste("Re-using existing Spark Context.", - "Please stop SparkR with sparkR.session.stop() or restart R to create a new Spark Context\n")) + "Call sparkR.session.stop() or restart R to create a new Spark Context\n")) return(get(".sparkRjsc", envir = .sparkREnv)) } jars <- processSparkJars(sparkJars) packages <- processSparkPackages(sparkPackages) - sparkEnvirMap <- convertNamedListToEnv(sparkEnvir) - existingPort <- Sys.getenv("EXISTING_SPARKR_BACKEND_PORT", "") if (existingPort != "") { backendPort <- existingPort @@ -204,7 +207,6 @@ sparkR.sparkContext <- function( sparkHome <- suppressWarnings(normalizePath(sparkHome)) } - sparkExecutorEnvMap <- convertNamedListToEnv(sparkExecutorEnv) if (is.null(sparkExecutorEnvMap$LD_LIBRARY_PATH)) { sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <- paste0("$LD_LIBRARY_PATH:", Sys.getenv("LD_LIBRARY_PATH")) @@ -298,6 +300,9 @@ sparkRHive.init <- function(jsc = NULL) { #' Get the existing SparkSession or initialize a new SparkSession. 
#' +#' Additional Spark properties can be set (...), and these named parameters takes priority over +#' over values in master, appName, named lists of sparkConfig. +#' #' @param master The Spark master URL #' @param appName Application name to register with cluster manager #' @param sparkHome Spark Home directory @@ -313,13 +318,13 @@ sparkRHive.init <- function(jsc = NULL) { #' df <- read.json(path) #' #' sparkR.session.getOrCreate("local[2]", "SparkR", "/home/spark") -#' sparkR.session.getOrCreate("local[2]", "SparkR", "/home/spark", -#' list(spark.executor.memory="1g")) #' sparkR.session.getOrCreate("yarn-client", "SparkR", "/home/spark", #' list(spark.executor.memory="4g"), #' list(LD_LIBRARY_PATH="/directory of JVM libraries (libjvm.so) on workers/"), #' c("one.jar", "two.jar", "three.jar"), #' c("com.databricks:spark-avro_2.10:2.0.1")) +#' sparkR.session.getOrCreate(spark.master = "yarn-client", +#' spark.executor.memory = "4g") #'} #' @note since 2.0.0 @@ -334,20 +339,39 @@ sparkR.session.getOrCreate <- function( enableHiveSupport = TRUE, ...) { + sparkConfigMap <- convertNamedListToEnv(sparkConfig) + namedParams <- list(...) + if (length(namedParams) > 0) { + paramMap <- convertNamedListToEnv(namedParams) + # Override for certain named parameters + if (exists("spark.master", envir = paramMap)) { + master = paramMap[["spark.master"]] + } + if (exists("spark.app.name", envir = paramMap)) { + appName = paramMap[["spark.app.name"]] + } + overrideEnvs(sparkConfigMap, paramMap) + } + + sparkExecutorConfigMap <- convertNamedListToEnv(sparkExecutorConfig) if (!exists(".sparkRjsc", envir = .sparkREnv)) { - sparkR.sparkContext(master, appName, sparkHome, sparkConfig, sparkExecutorConfig, sparkJars, - sparkPackages) + sparkR.sparkContext(master, appName, sparkHome, sparkConfigMap, sparkExecutorConfigMap, + sparkJars, sparkPackages) stopifnot(exists(".sparkRjsc", envir = .sparkREnv)) } if (exists(".sparkRsession", envir = .sparkREnv)) { sparkSession <- get(".sparkRsession", envir = .sparkREnv) - # TODO: apply config + # Apply config to Spark Context and Spark Session if already there + callJStatic("org.apache.spark.sql.api.r.SQLUtils", + "setSparkContextSessionConf", + sparkConfigMap) } else { jsc <- get(".sparkRjsc", envir = .sparkREnv) sparkSession <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getOrCreateSparkSession", jsc, + sparkConfigMap, enableHiveSupport) assign(".sparkRsession", sparkSession, envir = .sparkREnv) } diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index b1b8adaa66a25..aafb34472feb1 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -317,6 +317,15 @@ convertEnvsToList <- function(keys, vals) { }) } +# Utility function to merge 2 environments with the second overriding values in the first +# env1 is changed in place +overrideEnvs <- function(env1, env2) { + lapply(ls(env2), + function(name) { + env1[[name]] <- env2[[name]] + }) +} + # Utility function to capture the varargs into environment object varargsToEnv <- function(...) 
{ # Based on http://stackoverflow.com/a/3057419/4577954 diff --git a/R/pkg/inst/tests/testthat/test_Serde.R b/R/pkg/inst/tests/testthat/test_Serde.R index ce7a593d7f62a..1bc97e21e933a 100644 --- a/R/pkg/inst/tests/testthat/test_Serde.R +++ b/R/pkg/inst/tests/testthat/test_Serde.R @@ -76,4 +76,4 @@ test_that("SerDe of list of lists", { expect_equal(x, y) }) -#sparkR.session.stop() +sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_binaryFile.R b/R/pkg/inst/tests/testthat/test_binaryFile.R index 03ee52bea03cf..501fcaedd2887 100644 --- a/R/pkg/inst/tests/testthat/test_binaryFile.R +++ b/R/pkg/inst/tests/testthat/test_binaryFile.R @@ -89,4 +89,4 @@ test_that("saveAsObjectFile()/objectFile() works with multiple paths", { unlink(fileName2, recursive = TRUE) }) -#sparkR.session.stop() +sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_binary_function.R b/R/pkg/inst/tests/testthat/test_binary_function.R index acd42e4ea7322..f47c1ce9c644e 100644 --- a/R/pkg/inst/tests/testthat/test_binary_function.R +++ b/R/pkg/inst/tests/testthat/test_binary_function.R @@ -101,4 +101,4 @@ test_that("zipPartitions() on RDDs", { unlink(fileName) }) -#sparkR.session.stop() +sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_broadcast.R b/R/pkg/inst/tests/testthat/test_broadcast.R index c0fe9d1315139..b9bd080a444e5 100644 --- a/R/pkg/inst/tests/testthat/test_broadcast.R +++ b/R/pkg/inst/tests/testthat/test_broadcast.R @@ -48,4 +48,4 @@ test_that("without using broadcast variable", { expect_equal(actual, expected) }) -#sparkR.session.stop() +sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_includePackage.R b/R/pkg/inst/tests/testthat/test_includePackage.R index 697a9de44c6e8..65fb56d971d67 100644 --- a/R/pkg/inst/tests/testthat/test_includePackage.R +++ b/R/pkg/inst/tests/testthat/test_includePackage.R @@ -57,4 +57,4 @@ test_that("use include package", { } }) -#sparkR.session.stop() +sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index 9efd26e28b3f0..1bcc6652a8116 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -454,4 +454,4 @@ test_that("spark.survreg", { } }) -#sparkR.session.stop() +sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_parallelize_collect.R b/R/pkg/inst/tests/testthat/test_parallelize_collect.R index 12667d44a853f..765e797acd012 100644 --- a/R/pkg/inst/tests/testthat/test_parallelize_collect.R +++ b/R/pkg/inst/tests/testthat/test_parallelize_collect.R @@ -109,4 +109,4 @@ test_that("parallelize() and collect() work for lists of pairs (pairwise data)", expect_equal(collect(strPairsRDDD2), strPairs) }) -#sparkR.session.stop() +sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_rdd.R b/R/pkg/inst/tests/testthat/test_rdd.R index c53fa92cf8766..f957908a56d85 100644 --- a/R/pkg/inst/tests/testthat/test_rdd.R +++ b/R/pkg/inst/tests/testthat/test_rdd.R @@ -801,4 +801,4 @@ test_that("Test correct concurrency of RRDD.compute()", { expect_equal(count, 1000) }) -#sparkR.session.stop() +sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_shuffle.R b/R/pkg/inst/tests/testthat/test_shuffle.R index 590f45bf4c895..57c41bc7c3dac 100644 --- a/R/pkg/inst/tests/testthat/test_shuffle.R +++ b/R/pkg/inst/tests/testthat/test_shuffle.R @@ -221,4 +221,4 @@ test_that("test partitionBy with string keys", { expect_equal(sortKeyValueList(actual_second), sortKeyValueList(expected_second)) }) 
-#sparkR.session.stop() +sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_take.R b/R/pkg/inst/tests/testthat/test_take.R index b88e1fdf011ea..3cb2de33a514b 100644 --- a/R/pkg/inst/tests/testthat/test_take.R +++ b/R/pkg/inst/tests/testthat/test_take.R @@ -66,4 +66,4 @@ test_that("take() gives back the original elements in correct count and order", expect_equal(length(take(numVectorRDD, 0)), 0) }) -#sparkR.session.stop() +sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_textFile.R b/R/pkg/inst/tests/testthat/test_textFile.R index 84af0b4d04ef3..be19b983d88c5 100644 --- a/R/pkg/inst/tests/testthat/test_textFile.R +++ b/R/pkg/inst/tests/testthat/test_textFile.R @@ -161,4 +161,4 @@ test_that("Pipelined operations on RDDs created using textFile", { unlink(fileName) }) -#sparkR.session.stop() +sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R index 4cbe5f72a2e32..3939bc7cc3059 100644 --- a/R/pkg/inst/tests/testthat/test_utils.R +++ b/R/pkg/inst/tests/testthat/test_utils.R @@ -170,4 +170,17 @@ test_that("hashCode", { expect_error(hashCode("bc53d3605e8a5b7de1e8e271c2317645"), NA) }) -#sparkR.session.stop() +test_that("overrideEnvs", { + config <- new.env() + config[["spark.master"]] <- "foo" + config[["config_only"]] <- "ok" + param <- new.env() + param[["spark.master"]] <- "local" + param[["param_only"]] <- "blah" + overrideEnvs(config, param) + expect_equal(config[["spark.master"]], "local") + expect_equal(config[["param_only"]], "blah") + expect_equal(config[["config_only"]], "ok") +}) + +sparkR.session.stop() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index cb6c60ed1d745..d125099ea429c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.api.r import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} +import java.util.{Map => JMap} import scala.collection.JavaConverters._ import scala.util.matching.Regex @@ -36,17 +37,31 @@ import org.apache.spark.sql.types._ private[sql] object SQLUtils { SerDe.registerSqlSerDe((readSqlObject, writeSqlObject)) - def withHiveExternalCatalog(sc: SparkContext): SparkContext = { + private[this] def withHiveExternalCatalog(sc: SparkContext): SparkContext = { sc.conf.set(CATALOG_IMPLEMENTATION.key, "hive") sc } - def getOrCreateSparkSession(jsc: JavaSparkContext, enableHiveSupport: Boolean): SparkSession = { - if (SparkSession.hiveClassesArePresent && enableHiveSupport) { + def getOrCreateSparkSession( + jsc: JavaSparkContext, + sparkConfigMap: JMap[Object, Object], + enableHiveSupport: Boolean): SparkSession = { + val spark = if (SparkSession.hiveClassesArePresent && enableHiveSupport) { SparkSession.builder().sparkContext(withHiveExternalCatalog(jsc.sc)).getOrCreate() } else { SparkSession.builder().sparkContext(jsc.sc).getOrCreate() } + setSparkContextSessionConf(spark, sparkConfigMap) + spark + } + + def setSparkContextSessionConf(spark: SparkSession, sparkConfigMap: JMap[Object, Object]) = { + for ((name, value) <- sparkConfigMap.asScala) { + spark.conf.set(name.toString, value.toString) + } + for ((name, value) <- sparkConfigMap.asScala) { + spark.sparkContext.conf.set(name.toString, value.toString) + } } def getJavaSparkContext(spark: SparkSession): JavaSparkContext = { From 
e0750eb4f5bf299e49503bdfe02437633e44b572 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Sun, 12 Jun 2016 23:32:36 -0700 Subject: [PATCH 05/10] update doc --- R/pkg/R/sparkR.R | 3 +++ 1 file changed, 3 insertions(+) diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 65cb48c84a681..fcb7b0a267e6e 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -279,6 +279,9 @@ sparkRSQL.init <- function(jsc = NULL) { #' #' This function creates a HiveContext from an existing JavaSparkContext #' +#' Starting SparkR 2.0, a SparkSession is initialized and returned instead. +#' This API is deprecated and kept for backward compatibility only. +#' #' @param jsc The existing JavaSparkContext created with SparkR.init() #' @export #' @examples From 88b200f3ad2b1553675d4090408c4c981c83db5f Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Sun, 12 Jun 2016 23:42:11 -0700 Subject: [PATCH 06/10] fix scalastyle --- .../src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index d125099ea429c..c86f9777d6096 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -55,7 +55,9 @@ private[sql] object SQLUtils { spark } - def setSparkContextSessionConf(spark: SparkSession, sparkConfigMap: JMap[Object, Object]) = { + def setSparkContextSessionConf( + spark: SparkSession, + sparkConfigMap: JMap[Object, Object]): Unit = { for ((name, value) <- sparkConfigMap.asScala) { spark.conf.set(name.toString, value.toString) } From c4d24c2c484d2138c3f320e145259cfc64064d14 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Thu, 16 Jun 2016 16:28:54 -0700 Subject: [PATCH 07/10] more test, comment feedback --- R/pkg/R/sparkR.R | 1 - R/pkg/inst/tests/testthat/test_sparkSQL.R | 26 ++++++++++++++++++++++- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index fcb7b0a267e6e..9c024c78b240b 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -33,7 +33,6 @@ connExists <- function(env) { #' Also terminates the backend this R session is connected to #' @export sparkR.stop <- function() { - .Deprecated("sparkR.session.stop") sparkR.session.stop() } diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 559a911f2ac39..cd0ce561ee09e 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -2301,7 +2301,6 @@ test_that("createDataFrame sqlContext parameter backward compatibility", { test_that("randomSplit", { num <- 4000 df <- createDataFrame(data.frame(id = 1:num)) - weights <- c(2, 3, 5) df_list <- randomSplit(df, weights) expect_equal(length(weights), length(df_list)) @@ -2316,6 +2315,31 @@ test_that("randomSplit", { expect_true(all(sapply(abs(counts / num - weights / sum(weights)), function(e) { e < 0.05 }))) }) +test_that("Change config on SparkSession", { + conf <- callJMethod(sparkSession, "conf") + property <- paste0("spark.testing.", as.character(runif(1))) + value <- as.character(runif(1)) + callJMethod(conf, "set", property, value) + + value <- as.character(runif(1)) + l <- list(value) + names(l) <- property + sparkR.session.getOrCreate(l) + + conf <- callJMethod(sparkSession, "conf") + newValue <- callJMethod(conf, "get", property, "") + + expect_equal(value, newValue) +}) + 
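From the user's side, the behaviour exercised by the test above is that a repeat call to the session entry point reuses the existing SparkSession and pushes the supplied config onto both the session conf and the SparkContext conf (via SQLUtils.setSparkContextSessionConf). A minimal sketch, assuming a local Spark install and an illustrative property name:

# The second call returns the same SparkSession object; the new value is
# applied to both the session conf and the underlying SparkContext conf.
spark <- sparkR.session.getOrCreate()
sparkR.session.getOrCreate(sparkConfig = list(spark.sql.shuffle.partitions = "8"))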
+test_that("enableHiveSupport on SparkSession", { + setHiveContext(sc) + unsetHiveContext() + # if we are still here, it must be built with hive + conf <- callJMethod(sparkSession, "conf") + value <- callJMethod(conf, "get", "spark.sql.catalogImplementation", "") + expect_equal(value, "hive") +}) unlink(parquetPath) unlink(jsonPath) From 310b2cf67a198a3a820a2460758bc64ff012ac39 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Thu, 16 Jun 2016 18:12:42 -0700 Subject: [PATCH 08/10] review feedback, --- R/pkg/R/backend.R | 2 +- R/pkg/R/sparkR.R | 34 +++++++++---------- R/pkg/inst/tests/testthat/jarTest.R | 3 +- R/pkg/inst/tests/testthat/packageInAJarTest.R | 2 +- R/pkg/inst/tests/testthat/test_Serde.R | 2 +- R/pkg/inst/tests/testthat/test_binaryFile.R | 2 +- .../tests/testthat/test_binary_function.R | 2 +- R/pkg/inst/tests/testthat/test_broadcast.R | 2 +- R/pkg/inst/tests/testthat/test_context.R | 2 +- .../inst/tests/testthat/test_includePackage.R | 2 +- R/pkg/inst/tests/testthat/test_mllib.R | 2 +- .../tests/testthat/test_parallelize_collect.R | 2 +- R/pkg/inst/tests/testthat/test_rdd.R | 2 +- R/pkg/inst/tests/testthat/test_shuffle.R | 2 +- R/pkg/inst/tests/testthat/test_sparkSQL.R | 8 ++--- R/pkg/inst/tests/testthat/test_take.R | 2 +- R/pkg/inst/tests/testthat/test_textFile.R | 2 +- R/pkg/inst/tests/testthat/test_utils.R | 2 +- .../org/apache/spark/sql/api/r/SQLUtils.scala | 7 +++- 19 files changed, 43 insertions(+), 39 deletions(-) diff --git a/R/pkg/R/backend.R b/R/pkg/R/backend.R index 860511d48d0d3..03e70bb2cb82e 100644 --- a/R/pkg/R/backend.R +++ b/R/pkg/R/backend.R @@ -68,7 +68,7 @@ isRemoveMethod <- function(isStatic, objId, methodName) { # methodName - name of method to be invoked invokeJava <- function(isStatic, objId, methodName, ...) { if (!exists(".sparkRCon", .sparkREnv)) { - stop("No connection to backend found. Please re-run sparkR.session.getOrCreate()") + stop("No connection to backend found. Please re-run sparkR.session()") } # If this isn't a removeJObject call diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 9c024c78b240b..23a5b81b2d338 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -28,9 +28,8 @@ connExists <- function(env) { }) } -#' Stop the Spark context. -#' -#' Also terminates the backend this R session is connected to +#' @rdname sparkR.session.stop +#' @name sparkR.stop #' @export sparkR.stop <- function() { sparkR.session.stop() @@ -39,6 +38,8 @@ sparkR.stop <- function() { #' Stop the Spark Session and Spark Context. #' #' Also terminates the backend this R session is connected to. +#' @rdname sparkR.session.stop +#' @name sparkR.session.stop #' @export #' @note since 2.0.0 sparkR.session.stop <- function() { @@ -120,7 +121,7 @@ sparkR.init <- function( sparkExecutorEnv = list(), sparkJars = "", sparkPackages = "") { - .Deprecated("sparkR.session.getOrCreate") + .Deprecated("sparkR.session") sparkR.sparkContext(master, appName, sparkHome, @@ -264,14 +265,14 @@ sparkR.sparkContext <- function( #'} sparkRSQL.init <- function(jsc = NULL) { - .Deprecated("sparkR.session.getOrCreate") + .Deprecated("sparkR.session") if (exists(".sparkRsession", envir = .sparkREnv)) { return(get(".sparkRsession", envir = .sparkREnv)) } # Default to without Hive support for backward compatibility. - sparkR.session.getOrCreate(enableHiveSupport = FALSE) + sparkR.session(enableHiveSupport = FALSE) } #' Initialize a new HiveContext. 
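With sparkRSQL.init() above now delegating to sparkR.session(enableHiveSupport = FALSE), and sparkRHive.init() doing the same with Hive enabled in the next hunk, a caller migration looks roughly like the sketch below (master URL and app name are placeholders, not taken from the patch):

# Old 1.x entry points, deprecated but still functional:
#   sc <- sparkR.init(master = "local[2]")
#   sqlContext <- sparkRSQL.init(sc)    # session without Hive support
#   hiveContext <- sparkRHive.init(sc)  # session with Hive support
# New 2.0 entry point -- a single call; Hive support is enabled by default:
spark <- sparkR.session(master = "local[2]", appName = "SparkR")
sparkR.session.stop()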
@@ -290,14 +291,14 @@ sparkRSQL.init <- function(jsc = NULL) { #'} sparkRHive.init <- function(jsc = NULL) { - .Deprecated("sparkR.session.getOrCreate") + .Deprecated("sparkR.session") if (exists(".sparkRsession", envir = .sparkREnv)) { return(get(".sparkRsession", envir = .sparkREnv)) } # Default to without Hive support for backward compatibility. - sparkR.session.getOrCreate(enableHiveSupport = TRUE) + sparkR.session(enableHiveSupport = TRUE) } #' Get the existing SparkSession or initialize a new SparkSession. @@ -309,33 +310,32 @@ sparkRHive.init <- function(jsc = NULL) { #' @param appName Application name to register with cluster manager #' @param sparkHome Spark Home directory #' @param sparkConfig Named list of Spark configuration to set on worker nodes -#' @param sparkExecutorConfig Named list of Spark configuration to be used when launching executors +#' @param sparkExecutorEnv Named list of environment variables to be used when launching executors #' @param sparkJars Character vector of jar files to pass to the worker nodes #' @param sparkPackages Character vector of packages from spark-packages.org #' @param enableHiveSupport Enable support for Hive #' @export #' @examples #'\dontrun{ -#' sparkR.session.getOrCreate() +#' sparkR.session() #' df <- read.json(path) #' -#' sparkR.session.getOrCreate("local[2]", "SparkR", "/home/spark") -#' sparkR.session.getOrCreate("yarn-client", "SparkR", "/home/spark", +#' sparkR.session("local[2]", "SparkR", "/home/spark") +#' sparkR.session("yarn-client", "SparkR", "/home/spark", #' list(spark.executor.memory="4g"), #' list(LD_LIBRARY_PATH="/directory of JVM libraries (libjvm.so) on workers/"), #' c("one.jar", "two.jar", "three.jar"), #' c("com.databricks:spark-avro_2.10:2.0.1")) -#' sparkR.session.getOrCreate(spark.master = "yarn-client", +#' sparkR.session(spark.master = "yarn-client", #' spark.executor.memory = "4g") #'} #' @note since 2.0.0 -sparkR.session.getOrCreate <- function( +sparkR.session <- function( master = "", appName = "SparkR", sparkHome = Sys.getenv("SPARK_HOME"), sparkConfig = list(), - sparkExecutorConfig = list(), sparkJars = "", sparkPackages = "", enableHiveSupport = TRUE, @@ -355,9 +355,9 @@ sparkR.session.getOrCreate <- function( overrideEnvs(sparkConfigMap, paramMap) } - sparkExecutorConfigMap <- convertNamedListToEnv(sparkExecutorConfig) if (!exists(".sparkRjsc", envir = .sparkREnv)) { - sparkR.sparkContext(master, appName, sparkHome, sparkConfigMap, sparkExecutorConfigMap, + sparkExecutorEnvMap <- new.env() + sparkR.sparkContext(master, appName, sparkHome, sparkConfigMap, sparkExecutorEnvMap, sparkJars, sparkPackages) stopifnot(exists(".sparkRjsc", envir = .sparkREnv)) } diff --git a/R/pkg/inst/tests/testthat/jarTest.R b/R/pkg/inst/tests/testthat/jarTest.R index a56cf81c82bfc..84e4845f180b3 100644 --- a/R/pkg/inst/tests/testthat/jarTest.R +++ b/R/pkg/inst/tests/testthat/jarTest.R @@ -16,8 +16,7 @@ # library(SparkR) -sparkSession <- sparkR.session.getOrCreate() -sc <- SparkR:::callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) +sparkSession <- sparkR.session() helloTest <- SparkR:::callJStatic("sparkR.test.hello", "helloWorld", diff --git a/R/pkg/inst/tests/testthat/packageInAJarTest.R b/R/pkg/inst/tests/testthat/packageInAJarTest.R index 746f60b857e45..940c91f376cd5 100644 --- a/R/pkg/inst/tests/testthat/packageInAJarTest.R +++ b/R/pkg/inst/tests/testthat/packageInAJarTest.R @@ -17,7 +17,7 @@ library(SparkR) library(sparkPackageTest) -sparkSession <- sparkR.session.getOrCreate() 
+sparkSession <- sparkR.session() run1 <- myfunc(5L) diff --git a/R/pkg/inst/tests/testthat/test_Serde.R b/R/pkg/inst/tests/testthat/test_Serde.R index 1bc97e21e933a..b45e9ddcd2942 100644 --- a/R/pkg/inst/tests/testthat/test_Serde.R +++ b/R/pkg/inst/tests/testthat/test_Serde.R @@ -17,7 +17,7 @@ context("SerDe functionality") -sparkSession <- sparkR.session.getOrCreate() +sparkSession <- sparkR.session() test_that("SerDe of primitive types", { x <- callJStatic("SparkRHandler", "echo", 1L) diff --git a/R/pkg/inst/tests/testthat/test_binaryFile.R b/R/pkg/inst/tests/testthat/test_binaryFile.R index 501fcaedd2887..dc0581c61dc51 100644 --- a/R/pkg/inst/tests/testthat/test_binaryFile.R +++ b/R/pkg/inst/tests/testthat/test_binaryFile.R @@ -18,7 +18,7 @@ context("functions on binary files") # JavaSparkContext handle -sparkSession <- sparkR.session.getOrCreate() +sparkSession <- sparkR.session() sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) mockFile <- c("Spark is pretty.", "Spark is awesome.") diff --git a/R/pkg/inst/tests/testthat/test_binary_function.R b/R/pkg/inst/tests/testthat/test_binary_function.R index f47c1ce9c644e..f0b90d5a00fdd 100644 --- a/R/pkg/inst/tests/testthat/test_binary_function.R +++ b/R/pkg/inst/tests/testthat/test_binary_function.R @@ -18,7 +18,7 @@ context("binary functions") # JavaSparkContext handle -sparkSession <- sparkR.session.getOrCreate() +sparkSession <- sparkR.session() sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Data diff --git a/R/pkg/inst/tests/testthat/test_broadcast.R b/R/pkg/inst/tests/testthat/test_broadcast.R index b9bd080a444e5..2c23ee140e2fb 100644 --- a/R/pkg/inst/tests/testthat/test_broadcast.R +++ b/R/pkg/inst/tests/testthat/test_broadcast.R @@ -18,7 +18,7 @@ context("broadcast variables") # JavaSparkContext handle -sparkSession <- sparkR.session.getOrCreate() +sparkSession <- sparkR.session() sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Partitioned data diff --git a/R/pkg/inst/tests/testthat/test_context.R b/R/pkg/inst/tests/testthat/test_context.R index a143e14f2e3c2..5dfb74757cd69 100644 --- a/R/pkg/inst/tests/testthat/test_context.R +++ b/R/pkg/inst/tests/testthat/test_context.R @@ -56,7 +56,7 @@ test_that("Check masked functions", { # test_that("repeatedly starting and stopping SparkR", { # for (i in 1:4) { -# sparkR.session.getOrCreate() +# sparkR.session(enableHiveSupport = FALSE) # df <- createDataFrame(data.frame(dummy=1:i)) # expect_equal(count(df), i) # sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_includePackage.R b/R/pkg/inst/tests/testthat/test_includePackage.R index 65fb56d971d67..bb4682f8afe3e 100644 --- a/R/pkg/inst/tests/testthat/test_includePackage.R +++ b/R/pkg/inst/tests/testthat/test_includePackage.R @@ -18,7 +18,7 @@ context("include R packages") # JavaSparkContext handle -sparkSession <- sparkR.session.getOrCreate() +sparkSession <- sparkR.session() sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Partitioned data diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index 1bcc6652a8116..0fddf2e24e53b 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -20,7 +20,7 @@ library(testthat) context("MLlib functions") # Tests for MLlib functions in SparkR -sparkSession <- sparkR.session.getOrCreate() +sparkSession <- sparkR.session() 
test_that("formula of spark.glm", { training <- suppressWarnings(createDataFrame(iris)) diff --git a/R/pkg/inst/tests/testthat/test_parallelize_collect.R b/R/pkg/inst/tests/testthat/test_parallelize_collect.R index 765e797acd012..fc3dba69361a1 100644 --- a/R/pkg/inst/tests/testthat/test_parallelize_collect.R +++ b/R/pkg/inst/tests/testthat/test_parallelize_collect.R @@ -33,7 +33,7 @@ numPairs <- list(list(1, 1), list(1, 2), list(2, 2), list(2, 3)) strPairs <- list(list(strList, strList), list(strList, strList)) # JavaSparkContext handle -sparkSession <- sparkR.session.getOrCreate() +sparkSession <- sparkR.session() jsc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Tests diff --git a/R/pkg/inst/tests/testthat/test_rdd.R b/R/pkg/inst/tests/testthat/test_rdd.R index f957908a56d85..cc61ad0e3f189 100644 --- a/R/pkg/inst/tests/testthat/test_rdd.R +++ b/R/pkg/inst/tests/testthat/test_rdd.R @@ -18,7 +18,7 @@ context("basic RDD functions") # JavaSparkContext handle -sparkSession <- sparkR.session.getOrCreate() +sparkSession <- sparkR.session() sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Data diff --git a/R/pkg/inst/tests/testthat/test_shuffle.R b/R/pkg/inst/tests/testthat/test_shuffle.R index 57c41bc7c3dac..920ab4b8d3ccd 100644 --- a/R/pkg/inst/tests/testthat/test_shuffle.R +++ b/R/pkg/inst/tests/testthat/test_shuffle.R @@ -18,7 +18,7 @@ context("partitionBy, groupByKey, reduceByKey etc.") # JavaSparkContext handle -sparkSession <- sparkR.session.getOrCreate() +sparkSession <- sparkR.session() sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Data diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index cd0ce561ee09e..c6efaf7b822b2 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -60,7 +60,7 @@ unsetHiveContext <- function() { # Tests for SparkSQL functions in SparkR -sparkSession <- sparkR.session.getOrCreate() +sparkSession <- sparkR.session() sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) mockLines <- c("{\"name\":\"Michael\"}", @@ -96,8 +96,8 @@ test_that("calling sparkRSQL.init returns existing SparkSession", { expect_equal(suppressWarnings(sparkRSQL.init(sc)), sparkSession) }) -test_that("calling sparkR.session.getOrCreate returns existing SparkSession", { - expect_equal(sparkR.session.getOrCreate(), sparkSession) +test_that("calling sparkR.session returns existing SparkSession", { + expect_equal(sparkR.session(), sparkSession) }) test_that("infer types and check types", { @@ -2324,7 +2324,7 @@ test_that("Change config on SparkSession", { value <- as.character(runif(1)) l <- list(value) names(l) <- property - sparkR.session.getOrCreate(l) + sparkR.session(l) conf <- callJMethod(sparkSession, "conf") newValue <- callJMethod(conf, "get", property, "") diff --git a/R/pkg/inst/tests/testthat/test_take.R b/R/pkg/inst/tests/testthat/test_take.R index 3cb2de33a514b..d564d8b66800b 100644 --- a/R/pkg/inst/tests/testthat/test_take.R +++ b/R/pkg/inst/tests/testthat/test_take.R @@ -30,7 +30,7 @@ strList <- list("Dexter Morgan: Blood. Sometimes it sets my teeth on edge, ", "raising me. But they're both dead now. I didn't kill them. 
Honest.") # JavaSparkContext handle -sparkSession <- sparkR.session.getOrCreate() +sparkSession <- sparkR.session() sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) test_that("take() gives back the original elements in correct count and order", { diff --git a/R/pkg/inst/tests/testthat/test_textFile.R b/R/pkg/inst/tests/testthat/test_textFile.R index be19b983d88c5..d4a58698e632a 100644 --- a/R/pkg/inst/tests/testthat/test_textFile.R +++ b/R/pkg/inst/tests/testthat/test_textFile.R @@ -18,7 +18,7 @@ context("the textFile() function") # JavaSparkContext handle -sparkSession <- sparkR.session.getOrCreate() +sparkSession <- sparkR.session() sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) mockFile <- c("Spark is pretty.", "Spark is awesome.") diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R index 3939bc7cc3059..69946a17da281 100644 --- a/R/pkg/inst/tests/testthat/test_utils.R +++ b/R/pkg/inst/tests/testthat/test_utils.R @@ -18,7 +18,7 @@ context("functions in utils.R") # JavaSparkContext handle -sparkSession <- sparkR.session.getOrCreate() +sparkSession <- sparkR.session() sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) test_that("convertJListToRList() gives back (deserializes) the original JLists diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index c86f9777d6096..0a995d2e9d180 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -23,6 +23,7 @@ import java.util.{Map => JMap} import scala.collection.JavaConverters._ import scala.util.matching.Regex +import org.apache.spark.internal.Logging import org.apache.spark.SparkContext import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.api.r.SerDe @@ -34,7 +35,7 @@ import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import org.apache.spark.sql.execution.command.ShowTablesCommand import org.apache.spark.sql.types._ -private[sql] object SQLUtils { +private[sql] object SQLUtils extends Logging { SerDe.registerSqlSerDe((readSqlObject, writeSqlObject)) private[this] def withHiveExternalCatalog(sc: SparkContext): SparkContext = { @@ -49,6 +50,10 @@ private[sql] object SQLUtils { val spark = if (SparkSession.hiveClassesArePresent && enableHiveSupport) { SparkSession.builder().sparkContext(withHiveExternalCatalog(jsc.sc)).getOrCreate() } else { + if (enableHiveSupport) { + logWarning("SparkR: enableHiveSupport is requested for SparkSession but " + + "Spark is not built with Hive; falling back to without Hive support.") + } SparkSession.builder().sparkContext(jsc.sc).getOrCreate() } setSparkContextSessionConf(spark, sparkConfigMap) From 3a413c562f4144683a95373483bbc84dfa75db4c Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Thu, 16 Jun 2016 23:42:41 -0700 Subject: [PATCH 09/10] more update --- R/pkg/NAMESPACE | 2 +- R/pkg/R/SQLContext.R | 3 +- R/pkg/R/sparkR.R | 34 +++++++++++-------- R/pkg/inst/profile/shell.R | 8 ++--- R/pkg/inst/tests/testthat/test_Serde.R | 2 -- R/pkg/inst/tests/testthat/test_binaryFile.R | 2 -- .../tests/testthat/test_binary_function.R | 2 -- R/pkg/inst/tests/testthat/test_broadcast.R | 2 -- R/pkg/inst/tests/testthat/test_context.R | 12 +++++++ .../inst/tests/testthat/test_includePackage.R | 2 -- 
R/pkg/inst/tests/testthat/test_mllib.R | 2 -- .../tests/testthat/test_parallelize_collect.R | 2 -- R/pkg/inst/tests/testthat/test_rdd.R | 2 -- R/pkg/inst/tests/testthat/test_shuffle.R | 2 -- R/pkg/inst/tests/testthat/test_sparkSQL.R | 21 ++++++++---- R/pkg/inst/tests/testthat/test_take.R | 2 -- R/pkg/inst/tests/testthat/test_textFile.R | 2 -- R/pkg/inst/tests/testthat/test_utils.R | 2 -- 18 files changed, 54 insertions(+), 50 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 835181c2fe393..82e56ca437299 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -6,7 +6,7 @@ importFrom(methods, setGeneric, setMethod, setOldClass) #useDynLib(SparkR, stringHashCode) # S3 methods exported -export("sparkR.session.getOrCreate") +export("sparkR.session") export("sparkR.init") export("sparkR.stop") export("sparkR.session.stop") diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 2053aaf13b494..3232241f8af55 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -673,7 +673,8 @@ read.df.default <- function(path = NULL, source = NULL, schema = NULL, ...) { sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sparkSession, source, schema$jobj, options) } else { - sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sparkSession, source, options) + sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", + "loadDF", sparkSession, source, options) } dataFrame(sdf) } diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 23a5b81b2d338..0dfd7b753033e 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -87,7 +87,7 @@ sparkR.session.stop <- function() { clearJobjs() } -#' Initialize a new Spark Context. +#' (Deprecated) Initialize a new Spark Context. #' #' This function initializes a new SparkContext. For details on how to initialize #' and use SparkR, refer to SparkR programming guide at @@ -100,6 +100,8 @@ sparkR.session.stop <- function() { #' @param sparkExecutorEnv Named list of environment variables to be used when launching executors #' @param sparkJars Character vector of jar files to pass to the worker nodes #' @param sparkPackages Character vector of packages from spark-packages.org +#' @seealso \link{sparkR.session} +#' @rdname sparkR.init-deprecated #' @export #' @examples #'\dontrun{ @@ -248,7 +250,7 @@ sparkR.sparkContext <- function( sc } -#' Initialize a new SQLContext. +#' (Deprecated) Initialize a new SQLContext. #' #' This function creates a SparkContext from an existing JavaSparkContext and #' then uses it to initialize a new SQLContext @@ -257,6 +259,8 @@ sparkR.sparkContext <- function( #' This API is deprecated and kept for backward compatibility only. #' #' @param jsc The existing JavaSparkContext created with SparkR.init() +#' @seealso \link{sparkR.session} +#' @rdname sparkRSQL.init-deprecated #' @export #' @examples #'\dontrun{ @@ -275,7 +279,7 @@ sparkRSQL.init <- function(jsc = NULL) { sparkR.session(enableHiveSupport = FALSE) } -#' Initialize a new HiveContext. +#' (Deprecated) Initialize a new HiveContext. #' #' This function creates a HiveContext from an existing JavaSparkContext #' @@ -283,6 +287,8 @@ sparkRSQL.init <- function(jsc = NULL) { #' This API is deprecated and kept for backward compatibility only. 
#' #' @param jsc The existing JavaSparkContext created with SparkR.init() +#' @seealso \link{sparkR.session} +#' @rdname sparkRHive.init-deprecated #' @export #' @examples #'\dontrun{ @@ -303,17 +309,17 @@ sparkRHive.init <- function(jsc = NULL) { #' Get the existing SparkSession or initialize a new SparkSession. #' -#' Additional Spark properties can be set (...), and these named parameters takes priority over +#' Additional Spark properties can be set (...), and these named parameters take priority over #' over values in master, appName, named lists of sparkConfig. #' #' @param master The Spark master URL #' @param appName Application name to register with cluster manager #' @param sparkHome Spark Home directory #' @param sparkConfig Named list of Spark configuration to set on worker nodes -#' @param sparkExecutorEnv Named list of environment variables to be used when launching executors #' @param sparkJars Character vector of jar files to pass to the worker nodes #' @param sparkPackages Character vector of packages from spark-packages.org -#' @param enableHiveSupport Enable support for Hive +#' @param enableHiveSupport Enable support for Hive, fallback if not built with Hive support; once +#' set, this cannot be turned off on an existing session #' @export #' @examples #'\dontrun{ @@ -322,12 +328,10 @@ sparkRHive.init <- function(jsc = NULL) { #' #' sparkR.session("local[2]", "SparkR", "/home/spark") #' sparkR.session("yarn-client", "SparkR", "/home/spark", -#' list(spark.executor.memory="4g"), -#' list(LD_LIBRARY_PATH="/directory of JVM libraries (libjvm.so) on workers/"), -#' c("one.jar", "two.jar", "three.jar"), -#' c("com.databricks:spark-avro_2.10:2.0.1")) -#' sparkR.session(spark.master = "yarn-client", -#' spark.executor.memory = "4g") +#' list(spark.executor.memory="4g"), +#' c("one.jar", "two.jar", "three.jar"), +#' c("com.databricks:spark-avro_2.10:2.0.1")) +#' sparkR.session(spark.master = "yarn-client", spark.executor.memory = "4g") #'} #' @note since 2.0.0 @@ -347,10 +351,10 @@ sparkR.session <- function( paramMap <- convertNamedListToEnv(namedParams) # Override for certain named parameters if (exists("spark.master", envir = paramMap)) { - master = paramMap[["spark.master"]] + master <- paramMap[["spark.master"]] } if (exists("spark.app.name", envir = paramMap)) { - appName = paramMap[["spark.app.name"]] + appName <- paramMap[["spark.app.name"]] } overrideEnvs(sparkConfigMap, paramMap) } @@ -365,8 +369,10 @@ sparkR.session <- function( if (exists(".sparkRsession", envir = .sparkREnv)) { sparkSession <- get(".sparkRsession", envir = .sparkREnv) # Apply config to Spark Context and Spark Session if already there + # Cannot change enableHiveSupport callJStatic("org.apache.spark.sql.api.r.SQLUtils", "setSparkContextSessionConf", + sparkSession, sparkConfigMap) } else { jsc <- get(".sparkRjsc", envir = .sparkREnv) diff --git a/R/pkg/inst/profile/shell.R b/R/pkg/inst/profile/shell.R index d1bc6c43b2d96..919821196a7aa 100644 --- a/R/pkg/inst/profile/shell.R +++ b/R/pkg/inst/profile/shell.R @@ -18,16 +18,16 @@ .First <- function() { home <- Sys.getenv("SPARK_HOME") .libPaths(c(file.path(home, "R", "lib"), .libPaths())) - Sys.setenv(NOAWT=1) + Sys.setenv(NOAWT = 1) # Make sure SparkR package is the last loaded one old <- getOption("defaultPackages") options(defaultPackages = c(old, "SparkR")) - spark <- SparkR::sparkR.session.getOrCreate() - assign("spark", spark, envir=.GlobalEnv) + spark <- SparkR::sparkR.session() + assign("spark", spark, envir = .GlobalEnv) sc <- 
SparkR:::callJMethod(spark, "sparkContext") - assign("sc", sc, envir=.GlobalEnv) + assign("sc", sc, envir = .GlobalEnv) sparkVer <- SparkR:::callJMethod(sc, "version") cat("\n Welcome to") cat("\n") diff --git a/R/pkg/inst/tests/testthat/test_Serde.R b/R/pkg/inst/tests/testthat/test_Serde.R index b45e9ddcd2942..96fb6dda26450 100644 --- a/R/pkg/inst/tests/testthat/test_Serde.R +++ b/R/pkg/inst/tests/testthat/test_Serde.R @@ -75,5 +75,3 @@ test_that("SerDe of list of lists", { y <- callJStatic("SparkRHandler", "echo", x) expect_equal(x, y) }) - -sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_binaryFile.R b/R/pkg/inst/tests/testthat/test_binaryFile.R index dc0581c61dc51..b69f017de81d1 100644 --- a/R/pkg/inst/tests/testthat/test_binaryFile.R +++ b/R/pkg/inst/tests/testthat/test_binaryFile.R @@ -88,5 +88,3 @@ test_that("saveAsObjectFile()/objectFile() works with multiple paths", { unlink(fileName1, recursive = TRUE) unlink(fileName2, recursive = TRUE) }) - -sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_binary_function.R b/R/pkg/inst/tests/testthat/test_binary_function.R index f0b90d5a00fdd..6f51d20687277 100644 --- a/R/pkg/inst/tests/testthat/test_binary_function.R +++ b/R/pkg/inst/tests/testthat/test_binary_function.R @@ -100,5 +100,3 @@ test_that("zipPartitions() on RDDs", { unlink(fileName) }) - -sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_broadcast.R b/R/pkg/inst/tests/testthat/test_broadcast.R index 2c23ee140e2fb..cf1d43277105e 100644 --- a/R/pkg/inst/tests/testthat/test_broadcast.R +++ b/R/pkg/inst/tests/testthat/test_broadcast.R @@ -47,5 +47,3 @@ test_that("without using broadcast variable", { expected <- list(sum(randomMat) * 1, sum(randomMat) * 2) expect_equal(actual, expected) }) - -sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_context.R b/R/pkg/inst/tests/testthat/test_context.R index 5dfb74757cd69..f123187adf3ef 100644 --- a/R/pkg/inst/tests/testthat/test_context.R +++ b/R/pkg/inst/tests/testthat/test_context.R @@ -54,6 +54,17 @@ test_that("Check masked functions", { sort(namesOfMaskedCompletely, na.last = TRUE)) }) +test_that("repeatedly starting and stopping SparkR", { + for (i in 1:4) { + sc <- suppressWarnings(sparkR.init()) + rdd <- parallelize(sc, 1:20, 2L) + expect_equal(count(rdd), 20) + suppressWarnings(sparkR.stop()) + } +}) + +# Does not work consistently even with Hive off +# nolint start # test_that("repeatedly starting and stopping SparkR", { # for (i in 1:4) { # sparkR.session(enableHiveSupport = FALSE) @@ -63,6 +74,7 @@ test_that("Check masked functions", { # Sys.sleep(5) # Need more time to shutdown Hive metastore # } # }) +# nolint end test_that("rdd GC across sparkR.stop", { sc <- sparkR.sparkContext() # sc should get id 0 diff --git a/R/pkg/inst/tests/testthat/test_includePackage.R b/R/pkg/inst/tests/testthat/test_includePackage.R index bb4682f8afe3e..d6a3766539c02 100644 --- a/R/pkg/inst/tests/testthat/test_includePackage.R +++ b/R/pkg/inst/tests/testthat/test_includePackage.R @@ -56,5 +56,3 @@ test_that("use include package", { actual <- collect(data) } }) - -sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index 0fddf2e24e53b..c8c5ef2476b32 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -453,5 +453,3 @@ test_that("spark.survreg", { expect_equal(predict(model, rData)[[1]], 3.724591, tolerance = 1e-4) } }) - -sparkR.session.stop() diff --git 
a/R/pkg/inst/tests/testthat/test_parallelize_collect.R b/R/pkg/inst/tests/testthat/test_parallelize_collect.R index fc3dba69361a1..f79a8a70aafb1 100644 --- a/R/pkg/inst/tests/testthat/test_parallelize_collect.R +++ b/R/pkg/inst/tests/testthat/test_parallelize_collect.R @@ -108,5 +108,3 @@ test_that("parallelize() and collect() work for lists of pairs (pairwise data)", expect_equal(collect(strPairsRDDD1), strPairs) expect_equal(collect(strPairsRDDD2), strPairs) }) - -sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_rdd.R b/R/pkg/inst/tests/testthat/test_rdd.R index cc61ad0e3f189..429311d2924f0 100644 --- a/R/pkg/inst/tests/testthat/test_rdd.R +++ b/R/pkg/inst/tests/testthat/test_rdd.R @@ -800,5 +800,3 @@ test_that("Test correct concurrency of RRDD.compute()", { count <- callJMethod(zrdd, "count") expect_equal(count, 1000) }) - -sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_shuffle.R b/R/pkg/inst/tests/testthat/test_shuffle.R index 920ab4b8d3ccd..7d4f342016441 100644 --- a/R/pkg/inst/tests/testthat/test_shuffle.R +++ b/R/pkg/inst/tests/testthat/test_shuffle.R @@ -220,5 +220,3 @@ test_that("test partitionBy with string keys", { expect_equal(sortKeyValueList(actual_first), sortKeyValueList(expected_first)) expect_equal(sortKeyValueList(actual_second), sortKeyValueList(expected_second)) }) - -sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index c6efaf7b822b2..fcc2ab3ed6a2b 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -2316,20 +2316,29 @@ test_that("randomSplit", { }) test_that("Change config on SparkSession", { + # first, set it to a random but known value conf <- callJMethod(sparkSession, "conf") property <- paste0("spark.testing.", as.character(runif(1))) - value <- as.character(runif(1)) - callJMethod(conf, "set", property, value) + value1 <- as.character(runif(1)) + callJMethod(conf, "set", property, value1) - value <- as.character(runif(1)) - l <- list(value) + # next, change the same property to the new value + value2 <- as.character(runif(1)) + l <- list(value2) names(l) <- property - sparkR.session(l) + sparkR.session(sparkConfig = l) conf <- callJMethod(sparkSession, "conf") newValue <- callJMethod(conf, "get", property, "") + expect_equal(value2, newValue) - expect_equal(value, newValue) + value <- as.character(runif(1)) + sparkR.session(spark.app.name = "sparkSession test", spark.testing.r.session.r = value) + conf <- callJMethod(sparkSession, "conf") + appNameValue <- callJMethod(conf, "get", "spark.app.name", "") + testValue <- callJMethod(conf, "get", "spark.testing.r.session.r", "") + expect_equal(appNameValue, "sparkSession test") + expect_equal(testValue, value) }) test_that("enableHiveSupport on SparkSession", { diff --git a/R/pkg/inst/tests/testthat/test_take.R b/R/pkg/inst/tests/testthat/test_take.R index d564d8b66800b..daf5e41abe13f 100644 --- a/R/pkg/inst/tests/testthat/test_take.R +++ b/R/pkg/inst/tests/testthat/test_take.R @@ -65,5 +65,3 @@ test_that("take() gives back the original elements in correct count and order", expect_equal(length(take(numListRDD, 0)), 0) expect_equal(length(take(numVectorRDD, 0)), 0) }) - -sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_textFile.R b/R/pkg/inst/tests/testthat/test_textFile.R index d4a58698e632a..7b2cc74753fe2 100644 --- a/R/pkg/inst/tests/testthat/test_textFile.R +++ b/R/pkg/inst/tests/testthat/test_textFile.R @@ -160,5 +160,3 @@ 
test_that("Pipelined operations on RDDs created using textFile", { unlink(fileName) }) - -sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R index 69946a17da281..21a119a06b937 100644 --- a/R/pkg/inst/tests/testthat/test_utils.R +++ b/R/pkg/inst/tests/testthat/test_utils.R @@ -182,5 +182,3 @@ test_that("overrideEnvs", { expect_equal(config[["param_only"]], "blah") expect_equal(config[["config_only"]], "ok") }) - -sparkR.session.stop() From 4bc544938c7984877d568b98c725552a51aa3c01 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Fri, 17 Jun 2016 17:23:37 -0700 Subject: [PATCH 10/10] sc should be JavaSparkContext --- R/pkg/inst/profile/shell.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/pkg/inst/profile/shell.R b/R/pkg/inst/profile/shell.R index 919821196a7aa..8a8111a8c5419 100644 --- a/R/pkg/inst/profile/shell.R +++ b/R/pkg/inst/profile/shell.R @@ -26,7 +26,7 @@ spark <- SparkR::sparkR.session() assign("spark", spark, envir = .GlobalEnv) - sc <- SparkR:::callJMethod(spark, "sparkContext") + sc <- SparkR:::callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", spark) assign("sc", sc, envir = .GlobalEnv) sparkVer <- SparkR:::callJMethod(sc, "version") cat("\n Welcome to")