8 changes: 5 additions & 3 deletions R/pkg/NAMESPACE
@@ -6,10 +6,15 @@ importFrom(methods, setGeneric, setMethod, setOldClass)
#useDynLib(SparkR, stringHashCode)

# S3 methods exported
export("sparkR.session")
export("sparkR.init")
export("sparkR.stop")
export("sparkR.session.stop")
export("print.jobj")

export("sparkRSQL.init",
"sparkRHive.init")

# MLlib integration
exportMethods("glm",
"spark.glm",
@@ -287,9 +292,6 @@ exportMethods("%in%",
exportClasses("GroupedData")
exportMethods("agg")

export("sparkRSQL.init",
"sparkRHive.init")

export("as.DataFrame",
"cacheTable",
"clearCache",
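For context, a minimal usage sketch of the entry points exported above, assuming the SparkSession-based SparkR API this change introduces (the master/appName values are illustrative):

library(SparkR)

# sparkR.session() is the new unified entry point; it creates (or reuses) a
# SparkSession and replaces the sparkR.init()/sparkRSQL.init() pair.
sparkR.session(master = "local[*]", appName = "SparkSessionExample")

# SQL functions now pick up the active session implicitly.
df <- as.DataFrame(faithful)
head(df)

# sparkR.session.stop() is the new counterpart to sparkR.stop().
sparkR.session.stop()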
8 changes: 2 additions & 6 deletions R/pkg/R/DataFrame.R
@@ -2333,9 +2333,7 @@ setMethod("write.df",
signature(df = "SparkDataFrame", path = "character"),
function(df, path, source = NULL, mode = "error", ...){
if (is.null(source)) {
sqlContext <- getSqlContext()
source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default",
"org.apache.spark.sql.parquet")
source <- getDefaultSqlSource()
}
jmode <- convertToJSaveMode(mode)
options <- varargsToEnv(...)
@@ -2393,9 +2391,7 @@ setMethod("saveAsTable",
signature(df = "SparkDataFrame", tableName = "character"),
function(df, tableName, source = NULL, mode="error", ...){
if (is.null(source)) {
sqlContext <- getSqlContext()
source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default",
"org.apache.spark.sql.parquet")
source <- getDefaultSqlSource()
}
jmode <- convertToJSaveMode(mode)
options <- varargsToEnv(...)
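From the caller's side, the change above means write.df() and saveAsTable() now resolve a missing source through getDefaultSqlSource(), i.e. the session's spark.sql.sources.default setting, rather than reading the config off a SQLContext. A hedged sketch (the paths are hypothetical; parquet is the stock default):

# No source given: falls back to spark.sql.sources.default (parquet unless overridden).
write.df(df, path = "/tmp/people-default")

# An explicit source still takes precedence over the session default.
write.df(df, path = "/tmp/people-json", source = "json", mode = "overwrite")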
109 changes: 57 additions & 52 deletions R/pkg/R/SQLContext.R
@@ -53,7 +53,8 @@ dispatchFunc <- function(newFuncSig, x, ...) {
# Strip sqlContext from list of parameters and then pass the rest along.
contextNames <- c("org.apache.spark.sql.SQLContext",
"org.apache.spark.sql.hive.HiveContext",
"org.apache.spark.sql.hive.test.TestHiveContext")
"org.apache.spark.sql.hive.test.TestHiveContext",
"org.apache.spark.sql.SparkSession")
if (missing(x) && length(list(...)) == 0) {
f()
} else if (class(x) == "jobj" &&
@@ -65,14 +66,12 @@ }
}
}

#' return the SQL Context
getSqlContext <- function() {
if (exists(".sparkRHivesc", envir = .sparkREnv)) {
get(".sparkRHivesc", envir = .sparkREnv)
} else if (exists(".sparkRSQLsc", envir = .sparkREnv)) {
get(".sparkRSQLsc", envir = .sparkREnv)
#' return the SparkSession
getSparkSession <- function() {
if (exists(".sparkRsession", envir = .sparkREnv)) {
get(".sparkRsession", envir = .sparkREnv)
} else {
stop("SQL context not initialized")
stop("SparkSession not initialized")
}
}

@@ -109,6 +108,13 @@ infer_type <- function(x) {
}
}

getDefaultSqlSource <- function() {
sparkSession <- getSparkSession()
conf <- callJMethod(sparkSession, "conf")
source <- callJMethod(conf, "get", "spark.sql.sources.default", "org.apache.spark.sql.parquet")
source
}

#' Create a SparkDataFrame
#'
#' Converts R data.frame or list into SparkDataFrame.
@@ -131,7 +137,7 @@ infer_type <- function(x) {

# TODO(davies): support sampling and infer type from NA
createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) {
sqlContext <- getSqlContext()
sparkSession <- getSparkSession()
if (is.data.frame(data)) {
# get the names of columns, they will be put into RDD
if (is.null(schema)) {
@@ -158,7 +164,7 @@ createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) {
data <- do.call(mapply, append(args, data))
}
if (is.list(data)) {
sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sqlContext)
sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)
rdd <- parallelize(sc, data)
} else if (inherits(data, "RDD")) {
rdd <- data
@@ -201,7 +207,7 @@ createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) {
jrdd <- getJRDD(lapply(rdd, function(x) x), "row")
srdd <- callJMethod(jrdd, "rdd")
sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "createDF",
srdd, schema$jobj, sqlContext)
srdd, schema$jobj, sparkSession)
dataFrame(sdf)
}

@@ -265,10 +271,10 @@ setMethod("toDF", signature(x = "RDD"),
#' @method read.json default

read.json.default <- function(path) {
sqlContext <- getSqlContext()
sparkSession <- getSparkSession()
# Allow the user to have a more flexible definition of the text file path
paths <- as.list(suppressWarnings(normalizePath(path)))
read <- callJMethod(sqlContext, "read")
read <- callJMethod(sparkSession, "read")
sdf <- callJMethod(read, "json", paths)
dataFrame(sdf)
}
@@ -336,10 +342,10 @@ jsonRDD <- function(sqlContext, rdd, schema = NULL, samplingRatio = 1.0) {
#' @method read.parquet default

read.parquet.default <- function(path) {
sqlContext <- getSqlContext()
sparkSession <- getSparkSession()
# Allow the user to have a more flexible definition of the text file path
paths <- as.list(suppressWarnings(normalizePath(path)))
read <- callJMethod(sqlContext, "read")
read <- callJMethod(sparkSession, "read")
sdf <- callJMethod(read, "parquet", paths)
dataFrame(sdf)
}
@@ -385,10 +391,10 @@ parquetFile <- function(x, ...) {
#' @method read.text default

read.text.default <- function(path) {
sqlContext <- getSqlContext()
sparkSession <- getSparkSession()
# Allow the user to have a more flexible definition of the text file path
paths <- as.list(suppressWarnings(normalizePath(path)))
read <- callJMethod(sqlContext, "read")
read <- callJMethod(sparkSession, "read")
sdf <- callJMethod(read, "text", paths)
dataFrame(sdf)
}
@@ -418,8 +424,8 @@ read.text <- function(x, ...) {
#' @method sql default

sql.default <- function(sqlQuery) {
sqlContext <- getSqlContext()
sdf <- callJMethod(sqlContext, "sql", sqlQuery)
sparkSession <- getSparkSession()
sdf <- callJMethod(sparkSession, "sql", sqlQuery)
dataFrame(sdf)
}

@@ -449,8 +455,8 @@ sql <- function(x, ...) {
#' @note since 2.0.0

tableToDF <- function(tableName) {
sqlContext <- getSqlContext()
sdf <- callJMethod(sqlContext, "table", tableName)
sparkSession <- getSparkSession()
sdf <- callJMethod(sparkSession, "table", tableName)
dataFrame(sdf)
}

@@ -472,12 +478,8 @@ tableToDF <- function(tableName) {
#' @method tables default

tables.default <- function(databaseName = NULL) {
sqlContext <- getSqlContext()
jdf <- if (is.null(databaseName)) {
callJMethod(sqlContext, "tables")
} else {
callJMethod(sqlContext, "tables", databaseName)
}
sparkSession <- getSparkSession()
jdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getTables", sparkSession, databaseName)
dataFrame(jdf)
}

@@ -503,12 +505,11 @@ tables <- function(x, ...) {
#' @method tableNames default

tableNames.default <- function(databaseName = NULL) {
sqlContext <- getSqlContext()
if (is.null(databaseName)) {
callJMethod(sqlContext, "tableNames")
} else {
callJMethod(sqlContext, "tableNames", databaseName)
}
sparkSession <- getSparkSession()
callJStatic("org.apache.spark.sql.api.r.SQLUtils",
"getTableNames",
sparkSession,
databaseName)
}

tableNames <- function(x, ...) {
@@ -536,8 +537,9 @@ tableNames <- function(x, ...) {
#' @method cacheTable default

cacheTable.default <- function(tableName) {
sqlContext <- getSqlContext()
callJMethod(sqlContext, "cacheTable", tableName)
sparkSession <- getSparkSession()
catalog <- callJMethod(sparkSession, "catalog")
callJMethod(catalog, "cacheTable", tableName)
}

cacheTable <- function(x, ...) {
@@ -565,8 +567,9 @@ cacheTable <- function(x, ...) {
#' @method uncacheTable default

uncacheTable.default <- function(tableName) {
sqlContext <- getSqlContext()
callJMethod(sqlContext, "uncacheTable", tableName)
sparkSession <- getSparkSession()
catalog <- callJMethod(sparkSession, "catalog")
callJMethod(catalog, "uncacheTable", tableName)
}

uncacheTable <- function(x, ...) {
@@ -587,8 +590,9 @@ uncacheTable <- function(x, ...) {
#' @method clearCache default

clearCache.default <- function() {
sqlContext <- getSqlContext()
callJMethod(sqlContext, "clearCache")
sparkSession <- getSparkSession()
catalog <- callJMethod(sparkSession, "catalog")
callJMethod(catalog, "clearCache")
}

clearCache <- function() {
@@ -615,11 +619,12 @@ clearCache <- function() {
#' @method dropTempTable default

dropTempTable.default <- function(tableName) {
sqlContext <- getSqlContext()
sparkSession <- getSparkSession()
if (class(tableName) != "character") {
stop("tableName must be a string.")
}
callJMethod(sqlContext, "dropTempTable", tableName)
catalog <- callJMethod(sparkSession, "catalog")
callJMethod(catalog, "dropTempView", tableName)
Contributor: Not related to this PR, but I think we need to deprecate dropTempTable and call it dropTempView in SparkR as well? cc @liancheng

Contributor: +1

Member Author: +1. There are several API changes related to catalog that we are not changing here, as well.

Member Author: Will open a JIRA on this, I have the fix.

(A sketch of one possible dropTempView wrapper is shown after this file's diff.)
}

dropTempTable <- function(x, ...) {
@@ -655,21 +660,21 @@ dropTempTable <- function(x, ...) {
#' @method read.df default

read.df.default <- function(path = NULL, source = NULL, schema = NULL, ...) {
sqlContext <- getSqlContext()
sparkSession <- getSparkSession()
options <- varargsToEnv(...)
if (!is.null(path)) {
options[["path"]] <- path
}
if (is.null(source)) {
source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default",
"org.apache.spark.sql.parquet")
source <- getDefaultSqlSource()
}
if (!is.null(schema)) {
stopifnot(class(schema) == "structType")
sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sqlContext, source,
sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sparkSession, source,
schema$jobj, options)
} else {
sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sqlContext, source, options)
sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils",
"loadDF", sparkSession, source, options)
}
dataFrame(sdf)
}
@@ -715,12 +720,13 @@ loadDF <- function(x, ...) {
#' @method createExternalTable default

createExternalTable.default <- function(tableName, path = NULL, source = NULL, ...) {
sqlContext <- getSqlContext()
sparkSession <- getSparkSession()
options <- varargsToEnv(...)
if (!is.null(path)) {
options[["path"]] <- path
}
sdf <- callJMethod(sqlContext, "createExternalTable", tableName, source, options)
catalog <- callJMethod(sparkSession, "catalog")
sdf <- callJMethod(catalog, "createExternalTable", tableName, source, options)
dataFrame(sdf)
}

@@ -767,12 +773,11 @@ read.jdbc <- function(url, tableName,
partitionColumn = NULL, lowerBound = NULL, upperBound = NULL,
numPartitions = 0L, predicates = list(), ...) {
jprops <- varargsToJProperties(...)

read <- callJMethod(sqlContext, "read")
sparkSession <- getSparkSession()
read <- callJMethod(sparkSession, "read")
if (!is.null(partitionColumn)) {
if (is.null(numPartitions) || numPartitions == 0) {
sqlContext <- getSqlContext()
sc <- callJMethod(sqlContext, "sparkContext")
sc <- callJMethod(sparkSession, "sparkContext")
numPartitions <- callJMethod(sc, "defaultParallelism")
} else {
numPartitions <- numToInt(numPartitions)
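Regarding the review thread above about deprecating dropTempTable in favour of dropTempView: a hedged sketch of what that follow-up could look like, simplified to ignore the dispatchFunc shim used elsewhere in SQLContext.R (the wrapper and its message wording are assumptions, not part of this PR):

dropTempView <- function(viewName) {
  sparkSession <- getSparkSession()
  if (class(viewName) != "character") {
    stop("viewName must be a string.")
  }
  catalog <- callJMethod(sparkSession, "catalog")
  # Catalog.dropTempView() is the Spark 2.0 counterpart of the old dropTempTable().
  callJMethod(catalog, "dropTempView", viewName)
}

dropTempTable <- function(tableName) {
  .Deprecated("dropTempView")  # keep the old name working but steer callers to the new one
  dropTempView(tableName)
}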
2 changes: 1 addition & 1 deletion R/pkg/R/backend.R
@@ -68,7 +68,7 @@ isRemoveMethod <- function(isStatic, objId, methodName) {
# methodName - name of method to be invoked
invokeJava <- function(isStatic, objId, methodName, ...) {
if (!exists(".sparkRCon", .sparkREnv)) {
stop("No connection to backend found. Please re-run sparkR.init")
stop("No connection to backend found. Please re-run sparkR.session()")
}

# If this isn't a removeJObject call