From e6edd18456512d416315048bfc46ba14d3ec241c Mon Sep 17 00:00:00 2001 From: felixcheung Date: Fri, 9 Oct 2015 14:31:14 -0700 Subject: [PATCH 01/15] Refractor SQLContext and DataFrame functions to lookup sqlContext instance in the env --- R/pkg/R/DataFrame.R | 20 +++-------- R/pkg/R/SQLContext.R | 81 +++++++++++++++++++++++++++++++++++++++----- 2 files changed, 76 insertions(+), 25 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 0c2a194483b0..2f9832366478 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2209,13 +2209,7 @@ setMethod("write.df", signature(df = "SparkDataFrame", path = "character"), function(df, path, source = NULL, mode = "error", ...){ if (is.null(source)) { - if (exists(".sparkRSQLsc", envir = .sparkREnv)) { - sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) - } else if (exists(".sparkRHivesc", envir = .sparkREnv)) { - sqlContext <- get(".sparkRHivesc", envir = .sparkREnv) - } else { - stop("sparkRHive or sparkRSQL context has to be specified") - } + sqlContext <- getSqlContext() source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", "org.apache.spark.sql.parquet") } @@ -2277,15 +2271,9 @@ setMethod("saveAsTable", signature(df = "SparkDataFrame", tableName = "character"), function(df, tableName, source = NULL, mode="error", ...){ if (is.null(source)) { - if (exists(".sparkRSQLsc", envir = .sparkREnv)) { - sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) - } else if (exists(".sparkRHivesc", envir = .sparkREnv)) { - sqlContext <- get(".sparkRHivesc", envir = .sparkREnv) - } else { - stop("sparkRHive or sparkRSQL context has to be specified") - } - source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", - "org.apache.spark.sql.parquet") + sqlContext <- getSqlContext() + source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", + "org.apache.spark.sql.parquet") } jmode <- convertToJSaveMode(mode) options <- varargsToEnv(...) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 6b7a341bee88..dc48f05bfb27 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -35,6 +35,15 @@ getInternalType <- function(x) { POSIXlt = "timestamp", POSIXct = "timestamp", stop(paste("Unsupported type for SparkDataFrame:", class(x)))) +#' return the SQL Context +getSqlContext <- function() { + if (exists(".sparkRHivesc", envir = .sparkREnv)) { + get(".sparkRHivesc", envir = .sparkREnv) + } else if (exists(".sparkRSQLsc", envir = .sparkREnv)) { + get(".sparkRSQLsc", envir = .sparkREnv) + } else { + stop("SQL context not initialized") + } } #' infer the SQL type @@ -90,6 +99,11 @@ infer_type <- function(x) { #' } # TODO(davies): support sampling and infer type from NA +createDataFrame <- function(data, schema = NULL, samplingRatio = 1.0) { + sqlContext <- getSqlContext() + createDataFrame(sqlContext, data, schema, samplingRatio) +} + createDataFrame <- function(sqlContext, data, schema = NULL, samplingRatio = 1.0) { if (is.data.frame(data)) { # get the names of columns, they will be put into RDD @@ -190,13 +204,7 @@ setGeneric("toDF", function(x, ...) { standardGeneric("toDF") }) setMethod("toDF", signature(x = "RDD"), function(x, ...) { - sqlContext <- if (exists(".sparkRHivesc", envir = .sparkREnv)) { - get(".sparkRHivesc", envir = .sparkREnv) - } else if (exists(".sparkRSQLsc", envir = .sparkREnv)) { - get(".sparkRSQLsc", envir = .sparkREnv) - } else { - stop("no SQL context available") - } + sqlContext <- getSqlContext() createDataFrame(sqlContext, x, ...) 
}) @@ -225,6 +233,10 @@ read.json <- function(sqlContext, path) { read <- callJMethod(sqlContext, "read") sdf <- callJMethod(read, "json", paths) dataFrame(sdf) + +jsonFile <- function(path) { + sqlContext <- getSqlContext() + jsonFile(sqlContext, path) } #' @rdname read.json @@ -290,6 +302,9 @@ read.parquet <- function(sqlContext, path) { #' @name parquetFile #' @export # TODO: Implement saveasParquetFile and write examples for both +parquetFile <- function(...) { + sqlContext <- getSqlContext() + parquetFile(sqlContext, ...) parquetFile <- function(sqlContext, ...) { .Deprecated("read.parquet") read.parquet(sqlContext, unlist(list(...))) @@ -341,6 +356,11 @@ read.text <- function(sqlContext, path) { #' new_df <- sql(sqlContext, "SELECT * FROM table") #' } +sql <- function(sqlQuery) { + sqlContext <- getSqlContext() + sql(sqlContext, sqlQuery) +} + sql <- function(sqlContext, sqlQuery) { sdf <- callJMethod(sqlContext, "sql", sqlQuery) dataFrame(sdf) @@ -387,6 +407,11 @@ tableToDF <- function(sqlContext, tableName) { #' tables(sqlContext, "hive") #' } +tables <- function(databaseName = NULL) { + sqlContext <- getSqlContext() + tables(sqlContext, databaseName) +} + tables <- function(sqlContext, databaseName = NULL) { jdf <- if (is.null(databaseName)) { callJMethod(sqlContext, "tables") @@ -412,6 +437,11 @@ tables <- function(sqlContext, databaseName = NULL) { #' tableNames(sqlContext, "hive") #' } +tableNames <- function(databaseName = NULL) { + sqlContext <- getSqlContext() + tableNames(sqlContext, databaseName) +} + tableNames <- function(sqlContext, databaseName = NULL) { if (is.null(databaseName)) { callJMethod(sqlContext, "tableNames") @@ -439,6 +469,11 @@ tableNames <- function(sqlContext, databaseName = NULL) { #' cacheTable(sqlContext, "table") #' } +cacheTable <- function(tableName) { + sqlContext <- getSqlContext() + cacheTable(sqlContext, tableName) +} + cacheTable <- function(sqlContext, tableName) { callJMethod(sqlContext, "cacheTable", tableName) } @@ -461,6 +496,11 @@ cacheTable <- function(sqlContext, tableName) { #' uncacheTable(sqlContext, "table") #' } +uncacheTable <- function(tableName) { + sqlContext <- getSqlContext() + uncacheTable(sqlContext, tableName) +} + uncacheTable <- function(sqlContext, tableName) { callJMethod(sqlContext, "uncacheTable", tableName) } @@ -475,6 +515,11 @@ uncacheTable <- function(sqlContext, tableName) { #' clearCache(sqlContext) #' } +clearCache <- function() { + sqlContext <- getSqlContext() + callJMethod(sqlContext, "clearCache") +} + clearCache <- function(sqlContext) { callJMethod(sqlContext, "clearCache") } @@ -495,6 +540,11 @@ clearCache <- function(sqlContext) { #' dropTempTable(sqlContext, "table") #' } +dropTempTable <- function(tableName) { + sqlContext <- getSqlContext() + dropTempTable(sqlContext, tableName) +} + dropTempTable <- function(sqlContext, tableName) { if (class(tableName) != "character") { stop("tableName must be a string.") @@ -529,13 +579,17 @@ dropTempTable <- function(sqlContext, tableName) { #' df3 <- loadDF(sqlContext, "data/test_table", "parquet", mergeSchema = "true") #' } +read.df <- function(path = NULL, source = NULL, schema = NULL, ...) { + sqlContext <- getSqlContext() + read.df(sqlContext, path, source, schema, ...) +} + read.df <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) { options <- varargsToEnv(...) 
if (!is.null(path)) { options[["path"]] <- path } if (is.null(source)) { - sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", "org.apache.spark.sql.parquet") } @@ -549,8 +603,12 @@ read.df <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) dataFrame(sdf) } -#' @rdname read.df +#' @aliases read.df #' @name loadDF +loadDF <- function(path = NULL, source = NULL, schema = NULL, ...) { + read.df(path, source, schema, ...) +} + loadDF <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) { read.df(sqlContext, path, source, schema, ...) } @@ -577,6 +635,11 @@ loadDF <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) { #' df <- sparkRSQL.createExternalTable(sqlContext, "myjson", path="path/to/json", source="json") #' } +createExternalTable <- function(tableName, path = NULL, source = NULL, ...) { + sqlContext <- getSqlContext() + createExternalTable(sqlContext, tableName, path, source) +} + createExternalTable <- function(sqlContext, tableName, path = NULL, source = NULL, ...) { options <- varargsToEnv(...) if (!is.null(path)) { From d3b332a6cbe9c8e49b86e7e1eb9f72105a114cb1 Mon Sep 17 00:00:00 2001 From: felixcheung Date: Mon, 19 Oct 2015 23:31:44 -0700 Subject: [PATCH 02/15] change to get from feedback (+13 squashed commits) Squashed commits: [2c16ca8] add getClassName from feedback [b0348d7] fix the new as.DataFrame method [d8e91f3] fix test [8b3141a] Change to method dispatch update more tests and add tests for back compat [fd3a835] update tests [fa50f78] Improve route logic [efedce5] Method dispatch to support omission of 'sqlContext' argument [612f7f3] Refractor SQLContext and DataFrame functions to lookup sqlContext instance in the env [9382244] fix test [943889f] Change to method dispatch update more tests and add tests for back compat [c5c41c2] update tests [a8e1ea6] Improve route logic [c779c9d] Method dispatch to support omission of 'sqlContext' argument --- R/pkg/R/SQLContext.R | 134 +++++++++++++++++++--------------- R/pkg/R/jobj.R | 5 ++ R/pkg/inst/tests/test_mllib.R | 69 +++++++++++++++++ 3 files changed, 151 insertions(+), 57 deletions(-) create mode 100644 R/pkg/inst/tests/test_mllib.R diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index dc48f05bfb27..d62066a8ee53 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -35,6 +35,23 @@ getInternalType <- function(x) { POSIXlt = "timestamp", POSIXct = "timestamp", stop(paste("Unsupported type for SparkDataFrame:", class(x)))) +#' Temporary function to reroute old S3 Method call to new +#' We need to check the class of x to ensure it is SQLContext before dispatching +dispatchFunc <- function(newFuncSig, x, ...) { + funcName <- as.character(sys.call(sys.parent())[[1]]) + f <- get(paste0(funcName, ".default")) + # Strip sqlContext from list of parameters and then pass the rest along. + # In the following, if '&' is used instead of '&&', it warns about + # "the condition has length > 1 and only the first element will be used" + if (class(x) == "jobj" && + getClassName.jobj(x) == "org.apache.spark.sql.SQLContext") { + .Deprecated(newFuncSig, old = paste0(funcName, "(sqlContext...)")) + f(...) + } else { + f(x, ...) 
+ } +} + #' return the SQL Context getSqlContext <- function() { if (exists(".sparkRHivesc", envir = .sparkREnv)) { @@ -99,12 +116,12 @@ infer_type <- function(x) { #' } # TODO(davies): support sampling and infer type from NA -createDataFrame <- function(data, schema = NULL, samplingRatio = 1.0) { - sqlContext <- getSqlContext() - createDataFrame(sqlContext, data, schema, samplingRatio) +createDataFrame <- function(x, ...) { + dispatchFunc("createDataFrame(data, schema = NULL, samplingRatio = 1.0)", x, ...) } -createDataFrame <- function(sqlContext, data, schema = NULL, samplingRatio = 1.0) { +createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) { + sqlContext <- getSqlContext() if (is.data.frame(data)) { # get the names of columns, they will be put into RDD if (is.null(schema)) { @@ -181,8 +198,8 @@ createDataFrame <- function(sqlContext, data, schema = NULL, samplingRatio = 1.0 #' @rdname createDataFrame #' @aliases createDataFrame #' @export -as.DataFrame <- function(sqlContext, data, schema = NULL, samplingRatio = 1.0) { - createDataFrame(sqlContext, data, schema, samplingRatio) +as.DataFrame <- function(data, schema = NULL, samplingRatio = 1.0) { + createDataFrame(data, schema, samplingRatio) } #' toDF @@ -234,15 +251,15 @@ read.json <- function(sqlContext, path) { sdf <- callJMethod(read, "json", paths) dataFrame(sdf) -jsonFile <- function(path) { - sqlContext <- getSqlContext() - jsonFile(sqlContext, path) -} - #' @rdname read.json #' @name jsonFile #' @export -jsonFile <- function(sqlContext, path) { +jsonFile <- function(x, ...) { + dispatchFunc("jsonFile(path)", x, ...) +} + +jsonFile.default <- function(path) { + sqlContext <- getSqlContext() .Deprecated("read.json") read.json(sqlContext, path) } @@ -266,6 +283,7 @@ jsonFile <- function(sqlContext, path) { #' df <- jsonRDD(sqlContext, rdd) #'} +# TODO: remove - this method is no longer exported # TODO: support schema jsonRDD <- function(sqlContext, rdd, schema = NULL, samplingRatio = 1.0) { .Deprecated("read.json") @@ -302,10 +320,12 @@ read.parquet <- function(sqlContext, path) { #' @name parquetFile #' @export # TODO: Implement saveasParquetFile and write examples for both -parquetFile <- function(...) { +parquetFile <- function(x, ...) { + dispatchFunc("parquetFile(...)", x, ...) +} + +parquetFile.default <- function(...) { sqlContext <- getSqlContext() - parquetFile(sqlContext, ...) -parquetFile <- function(sqlContext, ...) { .Deprecated("read.parquet") read.parquet(sqlContext, unlist(list(...))) } @@ -356,14 +376,14 @@ read.text <- function(sqlContext, path) { #' new_df <- sql(sqlContext, "SELECT * FROM table") #' } -sql <- function(sqlQuery) { - sqlContext <- getSqlContext() - sql(sqlContext, sqlQuery) +sql <- function(x, ...) { + dispatchFunc("sql(sqlQuery)", x, ...) } -sql <- function(sqlContext, sqlQuery) { - sdf <- callJMethod(sqlContext, "sql", sqlQuery) - dataFrame(sdf) +sql.default <- function(sqlQuery) { + sqlContext <- getSqlContext() + sdf <- callJMethod(sqlContext, "sql", sqlQuery) + dataFrame(sdf) } #' Create a SparkDataFrame from a SparkSQL Table @@ -407,12 +427,12 @@ tableToDF <- function(sqlContext, tableName) { #' tables(sqlContext, "hive") #' } -tables <- function(databaseName = NULL) { - sqlContext <- getSqlContext() - tables(sqlContext, databaseName) +tables <- function(x, ...) { + dispatchFunc("tables(databaseName = NULL)", x, ...) 
} -tables <- function(sqlContext, databaseName = NULL) { +tables.default <- function(databaseName = NULL) { + sqlContext <- getSqlContext() jdf <- if (is.null(databaseName)) { callJMethod(sqlContext, "tables") } else { @@ -437,12 +457,12 @@ tables <- function(sqlContext, databaseName = NULL) { #' tableNames(sqlContext, "hive") #' } -tableNames <- function(databaseName = NULL) { - sqlContext <- getSqlContext() - tableNames(sqlContext, databaseName) +tableNames <- function(x, ...) { + dispatchFunc("tableNames(databaseName = NULL)", x, ...) } -tableNames <- function(sqlContext, databaseName = NULL) { +tableNames.default <- function(databaseName = NULL) { + sqlContext <- getSqlContext() if (is.null(databaseName)) { callJMethod(sqlContext, "tableNames") } else { @@ -469,12 +489,12 @@ tableNames <- function(sqlContext, databaseName = NULL) { #' cacheTable(sqlContext, "table") #' } -cacheTable <- function(tableName) { - sqlContext <- getSqlContext() - cacheTable(sqlContext, tableName) +cacheTable <- function(x, ...) { + dispatchFunc("cacheTable(tableName)", x, ...) } -cacheTable <- function(sqlContext, tableName) { +cacheTable.default <- function(tableName) { + sqlContext <- getSqlContext() callJMethod(sqlContext, "cacheTable", tableName) } @@ -496,12 +516,12 @@ cacheTable <- function(sqlContext, tableName) { #' uncacheTable(sqlContext, "table") #' } -uncacheTable <- function(tableName) { - sqlContext <- getSqlContext() - uncacheTable(sqlContext, tableName) +uncacheTable <- function(x, ...) { + dispatchFunc("uncacheTable(tableName)", x, ...) } -uncacheTable <- function(sqlContext, tableName) { +uncacheTable.default <- function(tableName) { + sqlContext <- getSqlContext() callJMethod(sqlContext, "uncacheTable", tableName) } @@ -515,12 +535,12 @@ uncacheTable <- function(sqlContext, tableName) { #' clearCache(sqlContext) #' } -clearCache <- function() { - sqlContext <- getSqlContext() - callJMethod(sqlContext, "clearCache") +clearCache <- function(x, ...) { + dispatchFunc("clearCache()", x, ...) } -clearCache <- function(sqlContext) { +clearCache.default <- function() { + sqlContext <- getSqlContext() callJMethod(sqlContext, "clearCache") } @@ -540,12 +560,12 @@ clearCache <- function(sqlContext) { #' dropTempTable(sqlContext, "table") #' } -dropTempTable <- function(tableName) { - sqlContext <- getSqlContext() - dropTempTable(sqlContext, tableName) +dropTempTable <- function(x, ...) { + dispatchFunc("dropTempTable(tableName)", x, ...) } -dropTempTable <- function(sqlContext, tableName) { +dropTempTable.default <- function(tableName) { + sqlContext <- getSqlContext() if (class(tableName) != "character") { stop("tableName must be a string.") } @@ -579,12 +599,12 @@ dropTempTable <- function(sqlContext, tableName) { #' df3 <- loadDF(sqlContext, "data/test_table", "parquet", mergeSchema = "true") #' } -read.df <- function(path = NULL, source = NULL, schema = NULL, ...) { - sqlContext <- getSqlContext() - read.df(sqlContext, path, source, schema, ...) +read.df <- function(x, ...) { + dispatchFunc("read.df(path = NULL, source = NULL, schema = NULL, ...)", x, ...) } -read.df <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) { +read.df.default <- function(path = NULL, source = NULL, schema = NULL, ...) { + sqlContext <- getSqlContext() options <- varargsToEnv(...) if (!is.null(path)) { options[["path"]] <- path @@ -605,12 +625,12 @@ read.df <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) 
#' @aliases read.df #' @name loadDF -loadDF <- function(path = NULL, source = NULL, schema = NULL, ...) { - read.df(path, source, schema, ...) +loadDF <- function(x, ...) { + dispatchFunc("loadDF(path = NULL, source = NULL, schema = NULL, ...)", x, ...) } -loadDF <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) { - read.df(sqlContext, path, source, schema, ...) +loadDF.default <- function(path = NULL, source = NULL, schema = NULL, ...) { + read.df(path, source, schema, ...) } #' Create an external table @@ -635,12 +655,12 @@ loadDF <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) { #' df <- sparkRSQL.createExternalTable(sqlContext, "myjson", path="path/to/json", source="json") #' } -createExternalTable <- function(tableName, path = NULL, source = NULL, ...) { - sqlContext <- getSqlContext() - createExternalTable(sqlContext, tableName, path, source) +createExternalTable <- function(x, ...) { + dispatchFunc("createExternalTable(tableName, path = NULL, source = NULL, ...)", x, ...) } -createExternalTable <- function(sqlContext, tableName, path = NULL, source = NULL, ...) { +createExternalTable.default <- function(tableName, path = NULL, source = NULL, ...) { + sqlContext <- getSqlContext() options <- varargsToEnv(...) if (!is.null(path)) { options[["path"]] <- path diff --git a/R/pkg/R/jobj.R b/R/pkg/R/jobj.R index 0838a7bb35e0..898e80648fc2 100644 --- a/R/pkg/R/jobj.R +++ b/R/pkg/R/jobj.R @@ -77,6 +77,11 @@ print.jobj <- function(x, ...) { cat("Java ref type", name, "id", x$id, "\n", sep = " ") } +getClassName.jobj <- function(x) { + cls <- callJMethod(x, "getClass") + callJMethod(cls, "getName") +} + cleanup.jobj <- function(jobj) { if (isValidJobj(jobj)) { objId <- jobj$id diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R new file mode 100644 index 000000000000..7cc7a4227e2d --- /dev/null +++ b/R/pkg/inst/tests/test_mllib.R @@ -0,0 +1,69 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +library(testthat) + +context("MLlib functions") + +# Tests for MLlib functions in SparkR + +sc <- sparkR.init() + +sqlContext <- sparkRSQL.init(sc) + +test_that("glm and predict", { + training <- createDataFrame(iris) + test <- select(training, "Sepal_Length") + model <- glm(Sepal_Width ~ Sepal_Length, training, family = "gaussian") + prediction <- predict(model, test) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") +}) + +test_that("predictions match with native glm", { + training <- createDataFrame(iris) + model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) +}) + +test_that("dot minus and intercept vs native glm", { + training <- createDataFrame(iris) + model <- glm(Sepal_Width ~ . - Species + 0, data = training) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) +}) + +test_that("feature interaction vs native glm", { + training <- createDataFrame(iris) + model <- glm(Sepal_Width ~ Species:Sepal_Length, data = training) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ Species:Sepal.Length, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) +}) + +test_that("summary coefficients match with native glm", { + training <- createDataFrame(iris) + stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training, solver = "l-bfgs")) + coefs <- as.vector(stats$coefficients) + rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris))) + expect_true(all(abs(rCoefs - coefs) < 1e-6)) + expect_true(all( + as.character(stats$features) == + c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica"))) +}) From d95288ef837ac0771d83a666325b740d1fc1b9d6 Mon Sep 17 00:00:00 2001 From: felixcheung Date: Sun, 22 May 2016 17:38:02 -0700 Subject: [PATCH 03/15] fix post rebase, add new wrapper for new API --- R/pkg/R/SQLContext.R | 49 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 40 insertions(+), 9 deletions(-) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index d62066a8ee53..4614461ae97e 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -35,6 +35,8 @@ getInternalType <- function(x) { POSIXlt = "timestamp", POSIXct = "timestamp", stop(paste("Unsupported type for SparkDataFrame:", class(x)))) +} + #' Temporary function to reroute old S3 Method call to new #' We need to check the class of x to ensure it is SQLContext before dispatching dispatchFunc <- function(newFuncSig, x, ...) { @@ -198,7 +200,12 @@ createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) { #' @rdname createDataFrame #' @aliases createDataFrame #' @export -as.DataFrame <- function(data, schema = NULL, samplingRatio = 1.0) { +as.DataFrame <- function(x, ...) { + dispatchFunc("as.DataFrame(data, schema = NULL, samplingRatio = 1.0)", x, ...) 
+} + +as.DataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) { + sqlContext <- getSqlContext() createDataFrame(data, schema, samplingRatio) } @@ -244,12 +251,19 @@ setMethod("toDF", signature(x = "RDD"), #' df <- read.json(sqlContext, path) #' df <- jsonFile(sqlContext, path) #' } -read.json <- function(sqlContext, path) { + +read.json <- function(x, ...) { + dispatchFunc("read.json(path)", x, ...) +} + +read.json.default <- function(path) { + sqlContext <- getSqlContext() # Allow the user to have a more flexible definiton of the text file path paths <- as.list(suppressWarnings(normalizePath(path))) read <- callJMethod(sqlContext, "read") sdf <- callJMethod(read, "json", paths) dataFrame(sdf) +} #' @rdname read.json #' @name jsonFile @@ -308,7 +322,13 @@ jsonRDD <- function(sqlContext, rdd, schema = NULL, samplingRatio = 1.0) { #' @rdname read.parquet #' @name read.parquet #' @export -read.parquet <- function(sqlContext, path) { + +read.parquet <- function(x, ...) { + dispatchFunc("parquetFile(...)", x, ...) +} + +read.parquet.default <- function(path) { + sqlContext <- getSqlContext() # Allow the user to have a more flexible definiton of the text file path paths <- as.list(suppressWarnings(normalizePath(path))) read <- callJMethod(sqlContext, "read") @@ -350,7 +370,13 @@ parquetFile.default <- function(...) { #' path <- "path/to/file.txt" #' df <- read.text(sqlContext, path) #' } -read.text <- function(sqlContext, path) { + +read.text <- function(x, ...) { + dispatchFunc("read.text(path)", x, ...) +} + +read.text.default <- function(path) { + sqlContext <- getSqlContext() # Allow the user to have a more flexible definiton of the text file path paths <- as.list(suppressWarnings(normalizePath(path))) read <- callJMethod(sqlContext, "read") @@ -407,7 +433,12 @@ sql.default <- function(sqlQuery) { #' new_df <- tableToDF(sqlContext, "table") #' } -tableToDF <- function(sqlContext, tableName) { +tableToDF <- function(x, ...) { + dispatchFunc("tableToDF(tableName)", x, ...) +} + +tableToDF.default <- function(sqlContext, tableName) { + sqlContext <- getSqlContext() sdf <- callJMethod(sqlContext, "table", tableName) dataFrame(sdf) } @@ -679,7 +710,6 @@ createExternalTable.default <- function(tableName, path = NULL, source = NULL, . #' Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash #' your external database systems. #' -#' @param sqlContext SQLContext to use #' @param url JDBC database url of the form `jdbc:subprotocol:subname` #' @param tableName the name of the table in the external database #' @param partitionColumn the name of a column of integral type that will be used for partitioning @@ -699,12 +729,12 @@ createExternalTable.default <- function(tableName, path = NULL, source = NULL, . #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' jdbcUrl <- "jdbc:mysql://localhost:3306/databasename" -#' df <- read.jdbc(sqlContext, jdbcUrl, "table", predicates = list("field<=123"), user = "username") -#' df2 <- read.jdbc(sqlContext, jdbcUrl, "table2", partitionColumn = "index", lowerBound = 0, +#' df <- read.jdbc(jdbcUrl, "table", predicates = list("field<=123"), user = "username") +#' df2 <- read.jdbc(jdbcUrl, "table2", partitionColumn = "index", lowerBound = 0, #' upperBound = 10000, user = "username", password = "password") #' } -read.jdbc <- function(sqlContext, url, tableName, +read.jdbc <- function(url, tableName, partitionColumn = NULL, lowerBound = NULL, upperBound = NULL, numPartitions = 0L, predicates = list(), ...) 
{ jprops <- varargsToJProperties(...) @@ -712,6 +742,7 @@ read.jdbc <- function(sqlContext, url, tableName, read <- callJMethod(sqlContext, "read") if (!is.null(partitionColumn)) { if (is.null(numPartitions) || numPartitions == 0) { + sqlContext <- getSqlContext() sc <- callJMethod(sqlContext, "sparkContext") numPartitions <- callJMethod(sc, "defaultParallelism") } else { From 7fd5ed65a7dcbacb023350e7f0b6d2f3bc7e6a42 Mon Sep 17 00:00:00 2001 From: felixcheung Date: Mon, 23 May 2016 16:33:35 -0700 Subject: [PATCH 04/15] fix tests --- R/pkg/R/SQLContext.R | 22 +- R/pkg/inst/tests/testthat/test_context.R | 2 +- R/pkg/inst/tests/testthat/test_mllib.R | 30 +- R/pkg/inst/tests/testthat/test_sparkSQL.R | 327 ++++++++++++---------- 4 files changed, 199 insertions(+), 182 deletions(-) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 4614461ae97e..97d4e98f8ee6 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -45,7 +45,9 @@ dispatchFunc <- function(newFuncSig, x, ...) { # Strip sqlContext from list of parameters and then pass the rest along. # In the following, if '&' is used instead of '&&', it warns about # "the condition has length > 1 and only the first element will be used" - if (class(x) == "jobj" && + if (missing(x) && length(list(...)) == 0) { + f() + } else if (class(x) == "jobj" && getClassName.jobj(x) == "org.apache.spark.sql.SQLContext") { .Deprecated(newFuncSig, old = paste0(funcName, "(sqlContext...)")) f(...) @@ -205,7 +207,6 @@ as.DataFrame <- function(x, ...) { } as.DataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) { - sqlContext <- getSqlContext() createDataFrame(data, schema, samplingRatio) } @@ -228,8 +229,7 @@ setGeneric("toDF", function(x, ...) { standardGeneric("toDF") }) setMethod("toDF", signature(x = "RDD"), function(x, ...) { - sqlContext <- getSqlContext() - createDataFrame(sqlContext, x, ...) + createDataFrame(x, ...) }) #' Create a SparkDataFrame from a JSON file. @@ -273,9 +273,8 @@ jsonFile <- function(x, ...) { } jsonFile.default <- function(path) { - sqlContext <- getSqlContext() .Deprecated("read.json") - read.json(sqlContext, path) + read.json(path) } @@ -324,7 +323,7 @@ jsonRDD <- function(sqlContext, rdd, schema = NULL, samplingRatio = 1.0) { #' @export read.parquet <- function(x, ...) { - dispatchFunc("parquetFile(...)", x, ...) + dispatchFunc("read.parquet(...)", x, ...) } read.parquet.default <- function(path) { @@ -345,9 +344,8 @@ parquetFile <- function(x, ...) { } parquetFile.default <- function(...) { - sqlContext <- getSqlContext() .Deprecated("read.parquet") - read.parquet(sqlContext, unlist(list(...))) + read.parquet(unlist(list(...))) } #' Create a SparkDataFrame from a text file. @@ -437,7 +435,7 @@ tableToDF <- function(x, ...) { dispatchFunc("tableToDF(tableName)", x, ...) } -tableToDF.default <- function(sqlContext, tableName) { +tableToDF.default <- function(tableName) { sqlContext <- getSqlContext() sdf <- callJMethod(sqlContext, "table", tableName) dataFrame(sdf) @@ -566,8 +564,8 @@ uncacheTable.default <- function(tableName) { #' clearCache(sqlContext) #' } -clearCache <- function(x, ...) { - dispatchFunc("clearCache()", x, ...) 
+clearCache <- function() { + dispatchFunc("clearCache()") } clearCache.default <- function() { diff --git a/R/pkg/inst/tests/testthat/test_context.R b/R/pkg/inst/tests/testthat/test_context.R index 0e5e15c0a96c..94bae3c15296 100644 --- a/R/pkg/inst/tests/testthat/test_context.R +++ b/R/pkg/inst/tests/testthat/test_context.R @@ -54,7 +54,7 @@ test_that("repeatedly starting and stopping SparkR SQL", { for (i in 1:4) { sc <- sparkR.init() sqlContext <- sparkRSQL.init(sc) - df <- createDataFrame(sqlContext, data.frame(a = 1:20)) + df <- createDataFrame(data.frame(a = 1:20)) expect_equal(count(df), 20) sparkR.stop() } diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index 5f8a27d4e094..59ef15c1e9fd 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -26,7 +26,7 @@ sc <- sparkR.init() sqlContext <- sparkRSQL.init(sc) test_that("formula of spark.glm", { - training <- suppressWarnings(createDataFrame(sqlContext, iris)) + training <- suppressWarnings(createDataFrame(iris)) # directly calling the spark API # dot minus and intercept vs native glm model <- spark.glm(training, Sepal_Width ~ . - Species + 0) @@ -41,7 +41,7 @@ test_that("formula of spark.glm", { expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) # glm should work with long formula - training <- suppressWarnings(createDataFrame(sqlContext, iris)) + training <- suppressWarnings(createDataFrame(iris)) training$LongLongLongLongLongName <- training$Sepal_Width training$VeryLongLongLongLonLongName <- training$Sepal_Length training$AnotherLongLongLongLongName <- training$Species @@ -53,7 +53,7 @@ test_that("formula of spark.glm", { }) test_that("spark.glm and predict", { - training <- suppressWarnings(createDataFrame(sqlContext, iris)) + training <- suppressWarnings(createDataFrame(iris)) # gaussian family model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species) prediction <- predict(model, training) @@ -80,7 +80,7 @@ test_that("spark.glm and predict", { test_that("spark.glm summary", { # gaussian family - training <- suppressWarnings(createDataFrame(sqlContext, iris)) + training <- suppressWarnings(createDataFrame(iris)) stats <- summary(spark.glm(training, Sepal_Width ~ Sepal_Length + Species)) rStats <- summary(glm(Sepal.Width ~ Sepal.Length + Species, data = iris)) @@ -99,7 +99,7 @@ test_that("spark.glm summary", { expect_equal(stats$aic, rStats$aic) # binomial family - df <- suppressWarnings(createDataFrame(sqlContext, iris)) + df <- suppressWarnings(createDataFrame(iris)) training <- df[df$Species %in% c("versicolor", "virginica"), ] stats <- summary(spark.glm(training, Species ~ Sepal_Length + Sepal_Width, family = binomial(link = "logit"))) @@ -128,7 +128,7 @@ test_that("spark.glm summary", { }) test_that("spark.glm save/load", { - training <- suppressWarnings(createDataFrame(sqlContext, iris)) + training <- suppressWarnings(createDataFrame(iris)) m <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species) s <- summary(m) @@ -157,7 +157,7 @@ test_that("spark.glm save/load", { test_that("formula of glm", { - training <- suppressWarnings(createDataFrame(sqlContext, iris)) + training <- suppressWarnings(createDataFrame(iris)) # dot minus and intercept vs native glm model <- glm(Sepal_Width ~ . 
- Species + 0, data = training) vals <- collect(select(predict(model, training), "prediction")) @@ -171,7 +171,7 @@ test_that("formula of glm", { expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) # glm should work with long formula - training <- suppressWarnings(createDataFrame(sqlContext, iris)) + training <- suppressWarnings(createDataFrame(iris)) training$LongLongLongLongLongName <- training$Sepal_Width training$VeryLongLongLongLonLongName <- training$Sepal_Length training$AnotherLongLongLongLongName <- training$Species @@ -183,7 +183,7 @@ test_that("formula of glm", { }) test_that("glm and predict", { - training <- suppressWarnings(createDataFrame(sqlContext, iris)) + training <- suppressWarnings(createDataFrame(iris)) # gaussian family model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) prediction <- predict(model, training) @@ -210,7 +210,7 @@ test_that("glm and predict", { test_that("glm summary", { # gaussian family - training <- suppressWarnings(createDataFrame(sqlContext, iris)) + training <- suppressWarnings(createDataFrame(iris)) stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training)) rStats <- summary(glm(Sepal.Width ~ Sepal.Length + Species, data = iris)) @@ -229,7 +229,7 @@ test_that("glm summary", { expect_equal(stats$aic, rStats$aic) # binomial family - df <- suppressWarnings(createDataFrame(sqlContext, iris)) + df <- suppressWarnings(createDataFrame(iris)) training <- df[df$Species %in% c("versicolor", "virginica"), ] stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training, family = binomial(link = "logit"))) @@ -258,7 +258,7 @@ test_that("glm summary", { }) test_that("glm save/load", { - training <- suppressWarnings(createDataFrame(sqlContext, iris)) + training <- suppressWarnings(createDataFrame(iris)) m <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) s <- summary(m) @@ -287,7 +287,7 @@ test_that("glm save/load", { test_that("spark.kmeans", { newIris <- iris newIris$Species <- NULL - training <- suppressWarnings(createDataFrame(sqlContext, newIris)) + training <- suppressWarnings(createDataFrame(newIris)) take(training, 1) @@ -365,7 +365,7 @@ test_that("spark.naiveBayes", { t <- as.data.frame(Titanic) t1 <- t[t$Freq > 0, -5] - df <- suppressWarnings(createDataFrame(sqlContext, t1)) + df <- suppressWarnings(createDataFrame(t1)) m <- spark.naiveBayes(df, Survived ~ .) 
s <- summary(m) expect_equal(as.double(s$apriori[1, "Yes"]), 0.5833333, tolerance = 1e-6) @@ -420,7 +420,7 @@ test_that("spark.survreg", { # data <- list(list(4, 1, 0, 0), list(3, 1, 2, 0), list(1, 1, 1, 0), list(1, 0, 1, 0), list(2, 1, 1, 1), list(2, 1, 0, 1), list(3, 0, 0, 1)) - df <- createDataFrame(sqlContext, data, c("time", "status", "x", "sex")) + df <- createDataFrame(data, c("time", "status", "x", "sex")) model <- spark.survreg(df, Surv(time, status) ~ x + sex) stats <- summary(model) coefs <- as.vector(stats$coefficients[, 1]) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 6a99b43e5aa5..6dacf670b300 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -99,8 +99,8 @@ test_that("structType and structField", { test_that("create DataFrame from RDD", { rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) - df <- createDataFrame(sqlContext, rdd, list("a", "b")) - dfAsDF <- as.DataFrame(sqlContext, rdd, list("a", "b")) + df <- createDataFrame(rdd, list("a", "b")) + dfAsDF <- as.DataFrame(rdd, list("a", "b")) expect_is(df, "SparkDataFrame") expect_is(dfAsDF, "SparkDataFrame") expect_equal(count(df), 10) @@ -116,8 +116,8 @@ test_that("create DataFrame from RDD", { expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) expect_equal(dtypes(dfAsDF), list(c("a", "int"), c("b", "string"))) - df <- createDataFrame(sqlContext, rdd) - dfAsDF <- as.DataFrame(sqlContext, rdd) + df <- createDataFrame(rdd) + dfAsDF <- as.DataFrame(rdd) expect_is(df, "SparkDataFrame") expect_is(dfAsDF, "SparkDataFrame") expect_equal(columns(df), c("_1", "_2")) @@ -125,13 +125,13 @@ test_that("create DataFrame from RDD", { schema <- structType(structField(x = "a", type = "integer", nullable = TRUE), structField(x = "b", type = "string", nullable = TRUE)) - df <- createDataFrame(sqlContext, rdd, schema) + df <- createDataFrame(rdd, schema) expect_is(df, "SparkDataFrame") expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) - df <- createDataFrame(sqlContext, rdd) + df <- createDataFrame(rdd) expect_is(df, "SparkDataFrame") expect_equal(count(df), 10) expect_equal(columns(df), c("a", "b")) @@ -139,9 +139,9 @@ test_that("create DataFrame from RDD", { schema <- structType(structField("name", "string"), structField("age", "integer"), structField("height", "float")) - df <- read.df(sqlContext, jsonPathNa, "json", schema) - df2 <- createDataFrame(sqlContext, toRDD(df), schema) - df2AsDF <- as.DataFrame(sqlContext, toRDD(df), schema) + df <- read.df(jsonPathNa, "json", schema) + df2 <- createDataFrame(toRDD(df), schema) + df2AsDF <- as.DataFrame(toRDD(df), schema) expect_equal(columns(df2), c("name", "age", "height")) expect_equal(columns(df2AsDF), c("name", "age", "height")) expect_equal(dtypes(df2), list(c("name", "string"), c("age", "int"), c("height", "float"))) @@ -154,7 +154,7 @@ test_that("create DataFrame from RDD", { localDF <- data.frame(name = c("John", "Smith", "Sarah"), age = c(19L, 23L, 18L), height = c(176.5, 181.4, 173.7)) - df <- createDataFrame(sqlContext, localDF, schema) + df <- createDataFrame(localDF, schema) expect_is(df, "SparkDataFrame") expect_equal(count(df), 3) expect_equal(columns(df), c("name", "age", "height")) @@ -180,37 +180,37 @@ test_that("create DataFrame from RDD", { test_that("convert NAs to null type in 
DataFrames", { rdd <- parallelize(sc, list(list(1L, 2L), list(NA, 4L))) - df <- createDataFrame(sqlContext, rdd, list("a", "b")) + df <- createDataFrame(rdd, list("a", "b")) expect_true(is.na(collect(df)[2, "a"])) expect_equal(collect(df)[2, "b"], 4L) l <- data.frame(x = 1L, y = c(1L, NA_integer_, 3L)) - df <- createDataFrame(sqlContext, l) + df <- createDataFrame(l) expect_equal(collect(df)[2, "x"], 1L) expect_true(is.na(collect(df)[2, "y"])) rdd <- parallelize(sc, list(list(1, 2), list(NA, 4))) - df <- createDataFrame(sqlContext, rdd, list("a", "b")) + df <- createDataFrame(rdd, list("a", "b")) expect_true(is.na(collect(df)[2, "a"])) expect_equal(collect(df)[2, "b"], 4) l <- data.frame(x = 1, y = c(1, NA_real_, 3)) - df <- createDataFrame(sqlContext, l) + df <- createDataFrame(l) expect_equal(collect(df)[2, "x"], 1) expect_true(is.na(collect(df)[2, "y"])) l <- list("a", "b", NA, "d") - df <- createDataFrame(sqlContext, l) + df <- createDataFrame(l) expect_true(is.na(collect(df)[3, "_1"])) expect_equal(collect(df)[4, "_1"], "d") l <- list("a", "b", NA_character_, "d") - df <- createDataFrame(sqlContext, l) + df <- createDataFrame(l) expect_true(is.na(collect(df)[3, "_1"])) expect_equal(collect(df)[4, "_1"], "d") l <- list(TRUE, FALSE, NA, TRUE) - df <- createDataFrame(sqlContext, l) + df <- createDataFrame(l) expect_true(is.na(collect(df)[3, "_1"])) expect_equal(collect(df)[4, "_1"], TRUE) }) @@ -244,40 +244,40 @@ test_that("toDF", { test_that("create DataFrame from list or data.frame", { l <- list(list(1, 2), list(3, 4)) - df <- createDataFrame(sqlContext, l, c("a", "b")) + df <- createDataFrame(l, c("a", "b")) expect_equal(columns(df), c("a", "b")) l <- list(list(a = 1, b = 2), list(a = 3, b = 4)) - df <- createDataFrame(sqlContext, l) + df <- createDataFrame(l) expect_equal(columns(df), c("a", "b")) a <- 1:3 b <- c("a", "b", "c") ldf <- data.frame(a, b) - df <- createDataFrame(sqlContext, ldf) + df <- createDataFrame(ldf) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) expect_equal(count(df), 3) ldf2 <- collect(df) expect_equal(ldf$a, ldf2$a) - irisdf <- suppressWarnings(createDataFrame(sqlContext, iris)) + irisdf <- suppressWarnings(createDataFrame(iris)) iris_collected <- collect(irisdf) expect_equivalent(iris_collected[, -5], iris[, -5]) expect_equal(iris_collected$Species, as.character(iris$Species)) - mtcarsdf <- createDataFrame(sqlContext, mtcars) + mtcarsdf <- createDataFrame(mtcars) expect_equivalent(collect(mtcarsdf), mtcars) bytes <- as.raw(c(1, 2, 3)) - df <- createDataFrame(sqlContext, list(list(bytes))) + df <- createDataFrame(list(list(bytes))) expect_equal(collect(df)[[1]][[1]], bytes) }) test_that("create DataFrame with different data types", { l <- list(a = 1L, b = 2, c = TRUE, d = "ss", e = as.Date("2012-12-13"), f = as.POSIXct("2015-03-15 12:13:14.056")) - df <- createDataFrame(sqlContext, list(l)) + df <- createDataFrame(list(l)) expect_equal(dtypes(df), list(c("a", "int"), c("b", "double"), c("c", "boolean"), c("d", "string"), c("e", "date"), c("f", "timestamp"))) expect_equal(count(df), 1) @@ -291,7 +291,7 @@ test_that("create DataFrame with complex types", { s <- listToStruct(list(a = "aa", b = 3L)) l <- list(as.list(1:10), list("a", "b"), e, s) - df <- createDataFrame(sqlContext, list(l), c("a", "b", "c", "d")) + df <- createDataFrame(list(l), c("a", "b", "c", "d")) expect_equal(dtypes(df), list(c("a", "array"), c("b", "array"), c("c", "map"), @@ -318,7 +318,7 @@ test_that("create DataFrame from a data.frame 
with complex types", { ldf$a_list <- list(list(1, 2), list(3, 4)) ldf$an_envir <- c(as.environment(list(a = 1, b = 2)), as.environment(list(c = 3))) - sdf <- createDataFrame(sqlContext, ldf) + sdf <- createDataFrame(ldf) collected <- collect(sdf) expect_identical(ldf[, 1, FALSE], collected[, 1, FALSE]) @@ -334,7 +334,7 @@ writeLines(mockLinesMapType, mapTypeJsonPath) test_that("Collect DataFrame with complex types", { # ArrayType - df <- read.json(sqlContext, complexTypeJsonPath) + df <- read.json(complexTypeJsonPath) ldf <- collect(df) expect_equal(nrow(ldf), 3) expect_equal(ncol(ldf), 3) @@ -346,7 +346,7 @@ test_that("Collect DataFrame with complex types", { # MapType schema <- structType(structField("name", "string"), structField("info", "map")) - df <- read.df(sqlContext, mapTypeJsonPath, "json", schema) + df <- read.df(mapTypeJsonPath, "json", schema) expect_equal(dtypes(df), list(c("name", "string"), c("info", "map"))) ldf <- collect(df) @@ -360,7 +360,7 @@ test_that("Collect DataFrame with complex types", { expect_equal(bob$height, 176.5) # StructType - df <- read.json(sqlContext, mapTypeJsonPath) + df <- read.json(mapTypeJsonPath) expect_equal(dtypes(df), list(c("info", "struct"), c("name", "string"))) ldf <- collect(df) @@ -376,7 +376,7 @@ test_that("Collect DataFrame with complex types", { test_that("read/write json files", { # Test read.df - df <- read.df(sqlContext, jsonPath, "json") + df <- read.df(jsonPath, "json") expect_is(df, "SparkDataFrame") expect_equal(count(df), 3) @@ -384,17 +384,17 @@ test_that("read/write json files", { schema <- structType(structField("name", type = "string"), structField("age", type = "double")) - df1 <- read.df(sqlContext, jsonPath, "json", schema) + df1 <- read.df(jsonPath, "json", schema) expect_is(df1, "SparkDataFrame") expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double"))) # Test loadDF - df2 <- loadDF(sqlContext, jsonPath, "json", schema) + df2 <- loadDF(jsonPath, "json", schema) expect_is(df2, "SparkDataFrame") expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double"))) # Test read.json - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) expect_is(df, "SparkDataFrame") expect_equal(count(df), 3) @@ -407,11 +407,11 @@ test_that("read/write json files", { write.json(df, jsonPath3) # Test read.json()/jsonFile() works with multiple input paths - jsonDF1 <- read.json(sqlContext, c(jsonPath2, jsonPath3)) + jsonDF1 <- read.json(c(jsonPath2, jsonPath3)) expect_is(jsonDF1, "SparkDataFrame") expect_equal(count(jsonDF1), 6) # Suppress warnings because jsonFile is deprecated - jsonDF2 <- suppressWarnings(jsonFile(sqlContext, c(jsonPath2, jsonPath3))) + jsonDF2 <- suppressWarnings(jsonFile(c(jsonPath2, jsonPath3))) expect_is(jsonDF2, "SparkDataFrame") expect_equal(count(jsonDF2), 6) @@ -433,82 +433,82 @@ test_that("jsonRDD() on a RDD with json string", { }) test_that("test cache, uncache and clearCache", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) registerTempTable(df, "table1") - cacheTable(sqlContext, "table1") - uncacheTable(sqlContext, "table1") - clearCache(sqlContext) - dropTempTable(sqlContext, "table1") + cacheTable("table1") + uncacheTable("table1") + clearCache() + dropTempTable("table1") }) test_that("test tableNames and tables", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) registerTempTable(df, "table1") - expect_equal(length(tableNames(sqlContext)), 1) - df <- tables(sqlContext) + expect_equal(length(tableNames()), 1) + df <- tables() 
expect_equal(count(df), 1) - dropTempTable(sqlContext, "table1") + dropTempTable("table1") }) test_that("registerTempTable() results in a queryable table and sql() results in a new DataFrame", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) registerTempTable(df, "table1") - newdf <- sql(sqlContext, "SELECT * FROM table1 where name = 'Michael'") + newdf <- sql("SELECT * FROM table1 where name = 'Michael'") expect_is(newdf, "SparkDataFrame") expect_equal(count(newdf), 1) - dropTempTable(sqlContext, "table1") + dropTempTable("table1") }) test_that("insertInto() on a registered table", { - df <- read.df(sqlContext, jsonPath, "json") + df <- read.df(jsonPath, "json") write.df(df, parquetPath, "parquet", "overwrite") - dfParquet <- read.df(sqlContext, parquetPath, "parquet") + dfParquet <- read.df(parquetPath, "parquet") lines <- c("{\"name\":\"Bob\", \"age\":24}", "{\"name\":\"James\", \"age\":35}") jsonPath2 <- tempfile(pattern = "jsonPath2", fileext = ".tmp") parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") writeLines(lines, jsonPath2) - df2 <- read.df(sqlContext, jsonPath2, "json") + df2 <- read.df(jsonPath2, "json") write.df(df2, parquetPath2, "parquet", "overwrite") - dfParquet2 <- read.df(sqlContext, parquetPath2, "parquet") + dfParquet2 <- read.df(parquetPath2, "parquet") registerTempTable(dfParquet, "table1") insertInto(dfParquet2, "table1") - expect_equal(count(sql(sqlContext, "select * from table1")), 5) - expect_equal(first(sql(sqlContext, "select * from table1 order by age"))$name, "Michael") - dropTempTable(sqlContext, "table1") + expect_equal(count(sql("select * from table1")), 5) + expect_equal(first(sql("select * from table1 order by age"))$name, "Michael") + dropTempTable("table1") registerTempTable(dfParquet, "table1") insertInto(dfParquet2, "table1", overwrite = TRUE) - expect_equal(count(sql(sqlContext, "select * from table1")), 2) - expect_equal(first(sql(sqlContext, "select * from table1 order by age"))$name, "Bob") - dropTempTable(sqlContext, "table1") + expect_equal(count(sql("select * from table1")), 2) + expect_equal(first(sql("select * from table1 order by age"))$name, "Bob") + dropTempTable("table1") unlink(jsonPath2) unlink(parquetPath2) }) test_that("tableToDF() returns a new DataFrame", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) registerTempTable(df, "table1") - tabledf <- tableToDF(sqlContext, "table1") + tabledf <- tableToDF("table1") expect_is(tabledf, "SparkDataFrame") expect_equal(count(tabledf), 3) - tabledf2 <- tableToDF(sqlContext, "table1") + tabledf2 <- tableToDF("table1") expect_equal(count(tabledf2), 3) - dropTempTable(sqlContext, "table1") + dropTempTable("table1") }) test_that("toRDD() returns an RRDD", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) testRDD <- toRDD(df) expect_is(testRDD, "RDD") expect_equal(count(testRDD), 3) }) test_that("union on two RDDs created from DataFrames returns an RRDD", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) RDD1 <- toRDD(df) RDD2 <- toRDD(df) unioned <- unionRDD(RDD1, RDD2) @@ -530,7 +530,7 @@ test_that("union on mixed serialization types correctly returns a byte RRDD", { writeLines(textLines, textPath) textRDD <- textFile(sc, textPath) - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) dfRDD <- toRDD(df) unionByte <- unionRDD(rdd, dfRDD) @@ -548,7 +548,7 @@ test_that("union on mixed serialization types correctly returns a byte RRDD", { test_that("objectFile() works with row 
serialization", { objectPath <- tempfile(pattern = "spark-test", fileext = ".tmp") - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) dfRDD <- toRDD(df) saveAsObjectFile(coalesce(dfRDD, 1L), objectPath) objectIn <- objectFile(sc, objectPath) @@ -559,7 +559,7 @@ test_that("objectFile() works with row serialization", { }) test_that("lapply() on a DataFrame returns an RDD with the correct columns", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) testRDD <- lapply(df, function(row) { row$newCol <- row$age + 5 row @@ -571,7 +571,7 @@ test_that("lapply() on a DataFrame returns an RDD with the correct columns", { }) test_that("collect() returns a data.frame", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) rdf <- collect(df) expect_true(is.data.frame(rdf)) expect_equal(names(rdf)[1], "age") @@ -587,20 +587,20 @@ test_that("collect() returns a data.frame", { expect_equal(ncol(rdf), 2) # collect() correctly handles multiple columns with same name - df <- createDataFrame(sqlContext, list(list(1, 2)), schema = c("name", "name")) + df <- createDataFrame(list(list(1, 2)), schema = c("name", "name")) ldf <- collect(df) expect_equal(names(ldf), c("name", "name")) }) test_that("limit() returns DataFrame with the correct number of rows", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) dfLimited <- limit(df, 2) expect_is(dfLimited, "SparkDataFrame") expect_equal(count(dfLimited), 2) }) test_that("collect() and take() on a DataFrame return the same number of rows and columns", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) expect_equal(nrow(collect(df)), nrow(take(df, 10))) expect_equal(ncol(collect(df)), ncol(take(df, 10))) }) @@ -614,7 +614,7 @@ test_that("collect() support Unicode characters", { jsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(lines, jsonPath) - df <- read.df(sqlContext, jsonPath, "json") + df <- read.df(jsonPath, "json") rdf <- collect(df) expect_true(is.data.frame(rdf)) expect_equal(rdf$name[1], markUtf8("안녕하세요")) @@ -622,12 +622,12 @@ test_that("collect() support Unicode characters", { expect_equal(rdf$name[3], markUtf8("こんにちは")) expect_equal(rdf$name[4], markUtf8("Xin chào")) - df1 <- createDataFrame(sqlContext, rdf) + df1 <- createDataFrame(rdf) expect_equal(collect(where(df1, df1$name == markUtf8("您好")))$name, markUtf8("您好")) }) test_that("multiple pipeline transformations result in an RDD with the correct values", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) first <- lapply(df, function(row) { row$age <- row$age + 5 row @@ -644,7 +644,7 @@ test_that("multiple pipeline transformations result in an RDD with the correct v }) test_that("cache(), persist(), and unpersist() on a DataFrame", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) expect_false(df@env$isCached) cache(df) expect_true(df@env$isCached) @@ -663,7 +663,7 @@ test_that("cache(), persist(), and unpersist() on a DataFrame", { }) test_that("schema(), dtypes(), columns(), names() return the correct values/format", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) testSchema <- schema(df) expect_equal(length(testSchema$fields()), 2) expect_equal(testSchema$fields()[[1]]$dataType.toString(), "LongType") @@ -684,7 +684,7 @@ test_that("schema(), dtypes(), columns(), names() return the correct values/form }) test_that("names() colnames() set the column names", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) 
names(df) <- c("col1", "col2") expect_equal(colnames(df)[2], "col2") @@ -699,7 +699,7 @@ test_that("names() colnames() set the column names", { expect_error(colnames(df) <- c("1", NA), "Column names cannot be NA.") # Note: if this test is broken, remove check for "." character on colnames<- method - irisDF <- suppressWarnings(createDataFrame(sqlContext, iris)) + irisDF <- suppressWarnings(createDataFrame(iris)) expect_equal(names(irisDF)[1], "Sepal_Length") # Test base::colnames base::names @@ -715,7 +715,7 @@ test_that("names() colnames() set the column names", { }) test_that("head() and first() return the correct data", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) testHead <- head(df) expect_equal(nrow(testHead), 3) expect_equal(ncol(testHead), 2) @@ -748,7 +748,7 @@ test_that("distinct(), unique() and dropDuplicates() on DataFrames", { jsonPathWithDup <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(lines, jsonPathWithDup) - df <- read.json(sqlContext, jsonPathWithDup) + df <- read.json(jsonPathWithDup) uniques <- distinct(df) expect_is(uniques, "SparkDataFrame") expect_equal(count(uniques), 3) @@ -759,7 +759,6 @@ test_that("distinct(), unique() and dropDuplicates() on DataFrames", { # Test dropDuplicates() df <- createDataFrame( - sqlContext, list( list(2, 1, 2), list(1, 1, 1), list(1, 2, 1), list(2, 1, 2), @@ -795,7 +794,7 @@ test_that("distinct(), unique() and dropDuplicates() on DataFrames", { }) test_that("sample on a DataFrame", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) sampled <- sample(df, FALSE, 1.0) expect_equal(nrow(collect(sampled)), count(df)) expect_is(sampled, "SparkDataFrame") @@ -817,7 +816,7 @@ test_that("sample on a DataFrame", { }) test_that("select operators", { - df <- select(read.json(sqlContext, jsonPath), "name", "age") + df <- select(read.json(jsonPath), "name", "age") expect_is(df$name, "Column") expect_is(df[[2]], "Column") expect_is(df[["age"]], "Column") @@ -846,7 +845,7 @@ test_that("select operators", { }) test_that("select with column", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) df1 <- select(df, "name") expect_equal(columns(df1), c("name")) expect_equal(count(df1), 3) @@ -869,7 +868,7 @@ test_that("select with column", { }) test_that("drop column", { - df <- select(read.json(sqlContext, jsonPath), "name", "age") + df <- select(read.json(jsonPath), "name", "age") df1 <- drop(df, "name") expect_equal(columns(df1), c("age")) @@ -891,7 +890,7 @@ test_that("drop column", { test_that("subsetting", { # read.json returns columns in random order - df <- select(read.json(sqlContext, jsonPath), "name", "age") + df <- select(read.json(jsonPath), "name", "age") filtered <- df[df$age > 20, ] expect_equal(count(filtered), 1) expect_equal(columns(filtered), c("name", "age")) @@ -928,7 +927,7 @@ test_that("subsetting", { }) test_that("selectExpr() on a DataFrame", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) selected <- selectExpr(df, "age * 2") expect_equal(names(selected), "(age * 2)") expect_equal(collect(selected), collect(select(df, df$age * 2L))) @@ -939,12 +938,12 @@ test_that("selectExpr() on a DataFrame", { }) test_that("expr() on a DataFrame", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) expect_equal(collect(select(df, expr("abs(-123)")))[1, 1], 123) }) test_that("column calculation", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) d <- collect(select(df, alias(df$age + 1, "age2"))) 
expect_equal(names(d), c("age2")) df2 <- select(df, lower(df$name), abs(df$age)) @@ -1025,7 +1024,7 @@ test_that("column functions", { expect_equal(class(rank())[[1]], "Column") expect_equal(rank(1:3), as.numeric(c(1:3))) - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) df2 <- select(df, between(df$age, c(20, 30)), between(df$age, c(10, 20))) expect_equal(collect(df2)[[2, 1]], TRUE) expect_equal(collect(df2)[[2, 2]], FALSE) @@ -1044,11 +1043,11 @@ test_that("column functions", { expect_true(abs(collect(select(df, stddev(df$age)))[1, 1] - 7.778175) < 1e-6) expect_equal(collect(select(df, var_pop(df$age)))[1, 1], 30.25) - df5 <- createDataFrame(sqlContext, list(list(a = "010101"))) + df5 <- createDataFrame(list(list(a = "010101"))) expect_equal(collect(select(df5, conv(df5$a, 2, 16)))[1, 1], "15") # Test array_contains() and sort_array() - df <- createDataFrame(sqlContext, list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L)))) + df <- createDataFrame(list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L)))) result <- collect(select(df, array_contains(df[[1]], 1L)))[[1]] expect_equal(result, c(TRUE, FALSE)) @@ -1061,8 +1060,7 @@ test_that("column functions", { expect_equal(length(lag(ldeaths, 12)), 72) # Test struct() - df <- createDataFrame(sqlContext, - list(list(1L, 2L, 3L), list(4L, 5L, 6L)), + df <- createDataFrame(list(list(1L, 2L, 3L), list(4L, 5L, 6L)), schema = c("a", "b", "c")) result <- collect(select(df, struct("a", "c"))) expected <- data.frame(row.names = 1:2) @@ -1078,15 +1076,14 @@ test_that("column functions", { # Test encode(), decode() bytes <- as.raw(c(0xe5, 0xa4, 0xa7, 0xe5, 0x8d, 0x83, 0xe4, 0xb8, 0x96, 0xe7, 0x95, 0x8c)) - df <- createDataFrame(sqlContext, - list(list(markUtf8("大千世界"), "utf-8", bytes)), + df <- createDataFrame(list(list(markUtf8("大千世界"), "utf-8", bytes)), schema = c("a", "b", "c")) result <- collect(select(df, encode(df$a, "utf-8"), decode(df$c, "utf-8"))) expect_equal(result[[1]][[1]], bytes) expect_equal(result[[2]], markUtf8("大千世界")) # Test first(), last() - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) expect_equal(collect(select(df, first(df$age)))[[1]], NA) expect_equal(collect(select(df, first(df$age, TRUE)))[[1]], 30) expect_equal(collect(select(df, first("age")))[[1]], NA) @@ -1097,7 +1094,7 @@ test_that("column functions", { expect_equal(collect(select(df, last("age", TRUE)))[[1]], 19) # Test bround() - df <- createDataFrame(sqlContext, data.frame(x = c(2.5, 3.5))) + df <- createDataFrame(data.frame(x = c(2.5, 3.5))) expect_equal(collect(select(df, bround(df$x, 0)))[[1]][1], 2) expect_equal(collect(select(df, bround(df$x, 0)))[[1]][2], 4) }) @@ -1109,7 +1106,7 @@ test_that("column binary mathfunctions", { "{\"a\":4, \"b\":8}") jsonPathWithDup <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(lines, jsonPathWithDup) - df <- read.json(sqlContext, jsonPathWithDup) + df <- read.json(jsonPathWithDup) expect_equal(collect(select(df, atan2(df$a, df$b)))[1, "ATAN2(a, b)"], atan2(1, 5)) expect_equal(collect(select(df, atan2(df$a, df$b)))[2, "ATAN2(a, b)"], atan2(2, 6)) expect_equal(collect(select(df, atan2(df$a, df$b)))[3, "ATAN2(a, b)"], atan2(3, 7)) @@ -1130,7 +1127,7 @@ test_that("column binary mathfunctions", { }) test_that("string operators", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) expect_equal(count(where(df, like(df$name, "A%"))), 1) expect_equal(count(where(df, startsWith(df$name, "A"))), 1) expect_equal(first(select(df, substr(df$name, 1, 2)))[[1]], "Mi") @@ 
-1150,14 +1147,14 @@ test_that("string operators", { expect_equal(collect(select(df, regexp_replace(df$name, "(n.y)", "ydn")))[2, 1], "Aydn") l2 <- list(list(a = "aaads")) - df2 <- createDataFrame(sqlContext, l2) + df2 <- createDataFrame(l2) expect_equal(collect(select(df2, locate("aa", df2$a)))[1, 1], 1) expect_equal(collect(select(df2, locate("aa", df2$a, 1)))[1, 1], 2) expect_equal(collect(select(df2, lpad(df2$a, 8, "#")))[1, 1], "###aaads") # nolint expect_equal(collect(select(df2, rpad(df2$a, 8, "#")))[1, 1], "aaads###") # nolint l3 <- list(list(a = "a.b.c.d")) - df3 <- createDataFrame(sqlContext, l3) + df3 <- createDataFrame(l3) expect_equal(collect(select(df3, substring_index(df3$a, ".", 2)))[1, 1], "a.b") expect_equal(collect(select(df3, substring_index(df3$a, ".", -3)))[1, 1], "b.c.d") expect_equal(collect(select(df3, translate(df3$a, "bc", "12")))[1, 1], "a.1.2.d") @@ -1169,7 +1166,7 @@ test_that("date functions on a DataFrame", { l <- list(list(a = 1L, b = as.Date("2012-12-13")), list(a = 2L, b = as.Date("2013-12-14")), list(a = 3L, b = as.Date("2014-12-15"))) - df <- createDataFrame(sqlContext, l) + df <- createDataFrame(l) expect_equal(collect(select(df, dayofmonth(df$b)))[, 1], c(13, 14, 15)) expect_equal(collect(select(df, dayofyear(df$b)))[, 1], c(348, 348, 349)) expect_equal(collect(select(df, weekofyear(df$b)))[, 1], c(50, 50, 51)) @@ -1189,7 +1186,7 @@ test_that("date functions on a DataFrame", { l2 <- list(list(a = 1L, b = as.POSIXlt("2012-12-13 12:34:00", tz = "UTC")), list(a = 2L, b = as.POSIXlt("2014-12-15 01:24:34", tz = "UTC"))) - df2 <- createDataFrame(sqlContext, l2) + df2 <- createDataFrame(l2) expect_equal(collect(select(df2, minute(df2$b)))[, 1], c(34, 24)) expect_equal(collect(select(df2, second(df2$b)))[, 1], c(0, 34)) expect_equal(collect(select(df2, from_utc_timestamp(df2$b, "JST")))[, 1], @@ -1201,7 +1198,7 @@ test_that("date functions on a DataFrame", { expect_gt(collect(select(df2, unix_timestamp(lit("2015-01-01"), "yyyy-MM-dd")))[1, 1], 0) l3 <- list(list(a = 1000), list(a = -1000)) - df3 <- createDataFrame(sqlContext, l3) + df3 <- createDataFrame(l3) result31 <- collect(select(df3, from_unixtime(df3$a))) expect_equal(grep("\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}", result31[, 1], perl = TRUE), c(1, 2)) @@ -1212,13 +1209,13 @@ test_that("date functions on a DataFrame", { test_that("greatest() and least() on a DataFrame", { l <- list(list(a = 1, b = 2), list(a = 3, b = 4)) - df <- createDataFrame(sqlContext, l) + df <- createDataFrame(l) expect_equal(collect(select(df, greatest(df$a, df$b)))[, 1], c(2, 4)) expect_equal(collect(select(df, least(df$a, df$b)))[, 1], c(1, 3)) }) test_that("time windowing (window()) with all inputs", { - df <- createDataFrame(sqlContext, data.frame(t = c("2016-03-11 09:00:07"), v = c(1))) + df <- createDataFrame(data.frame(t = c("2016-03-11 09:00:07"), v = c(1))) df$window <- window(df$t, "5 seconds", "5 seconds", "0 seconds") local <- collect(df)$v # Not checking time windows because of possible time zone issues. Just checking that the function @@ -1227,7 +1224,7 @@ test_that("time windowing (window()) with all inputs", { }) test_that("time windowing (window()) with slide duration", { - df <- createDataFrame(sqlContext, data.frame(t = c("2016-03-11 09:00:07"), v = c(1))) + df <- createDataFrame(data.frame(t = c("2016-03-11 09:00:07"), v = c(1))) df$window <- window(df$t, "5 seconds", "2 seconds") local <- collect(df)$v # Not checking time windows because of possible time zone issues. 
Just checking that the function @@ -1236,7 +1233,7 @@ test_that("time windowing (window()) with slide duration", { }) test_that("time windowing (window()) with start time", { - df <- createDataFrame(sqlContext, data.frame(t = c("2016-03-11 09:00:07"), v = c(1))) + df <- createDataFrame(data.frame(t = c("2016-03-11 09:00:07"), v = c(1))) df$window <- window(df$t, "5 seconds", startTime = "2 seconds") local <- collect(df)$v # Not checking time windows because of possible time zone issues. Just checking that the function @@ -1245,7 +1242,7 @@ test_that("time windowing (window()) with start time", { }) test_that("time windowing (window()) with just window duration", { - df <- createDataFrame(sqlContext, data.frame(t = c("2016-03-11 09:00:07"), v = c(1))) + df <- createDataFrame(data.frame(t = c("2016-03-11 09:00:07"), v = c(1))) df$window <- window(df$t, "5 seconds") local <- collect(df)$v # Not checking time windows because of possible time zone issues. Just checking that the function @@ -1255,7 +1252,7 @@ test_that("time windowing (window()) with just window duration", { test_that("when(), otherwise() and ifelse() on a DataFrame", { l <- list(list(a = 1, b = 2), list(a = 3, b = 4)) - df <- createDataFrame(sqlContext, l) + df <- createDataFrame(l) expect_equal(collect(select(df, when(df$a > 1 & df$b > 2, 1)))[, 1], c(NA, 1)) expect_equal(collect(select(df, otherwise(when(df$a > 1, 1), 0)))[, 1], c(0, 1)) expect_equal(collect(select(df, ifelse(df$a > 1 & df$b > 2, 0, 1)))[, 1], c(1, 0)) @@ -1263,14 +1260,14 @@ test_that("when(), otherwise() and ifelse() on a DataFrame", { test_that("when(), otherwise() and ifelse() with column on a DataFrame", { l <- list(list(a = 1, b = 2), list(a = 3, b = 4)) - df <- createDataFrame(sqlContext, l) + df <- createDataFrame(l) expect_equal(collect(select(df, when(df$a > 1 & df$b > 2, lit(1))))[, 1], c(NA, 1)) expect_equal(collect(select(df, otherwise(when(df$a > 1, lit(1)), lit(0))))[, 1], c(0, 1)) expect_equal(collect(select(df, ifelse(df$a > 1 & df$b > 2, lit(0), lit(1))))[, 1], c(1, 0)) }) test_that("group by, agg functions", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) df1 <- agg(df, name = "max", age = "sum") expect_equal(1, count(df1)) df1 <- agg(df, age2 = max(df$age)) @@ -1315,7 +1312,7 @@ test_that("group by, agg functions", { "{\"name\":\"ID2\", \"value\": \"-3\"}") jsonPath2 <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(mockLines2, jsonPath2) - gd2 <- groupBy(read.json(sqlContext, jsonPath2), "name") + gd2 <- groupBy(read.json(jsonPath2), "name") df6 <- agg(gd2, value = "sum") df6_local <- collect(df6) expect_equal(42, df6_local[df6_local$name == "ID1", ][1, 2]) @@ -1332,7 +1329,7 @@ test_that("group by, agg functions", { "{\"name\":\"Justin\", \"age\":1}") jsonPath3 <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(mockLines3, jsonPath3) - df8 <- read.json(sqlContext, jsonPath3) + df8 <- read.json(jsonPath3) gd3 <- groupBy(df8, "name") gd3_local <- collect(sum(gd3)) expect_equal(60, gd3_local[gd3_local$name == "Andy", ][1, 2]) @@ -1351,7 +1348,7 @@ test_that("group by, agg functions", { }) test_that("arrange() and orderBy() on a DataFrame", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) sorted <- arrange(df, df$age) expect_equal(collect(sorted)[1, 2], "Michael") @@ -1377,7 +1374,7 @@ test_that("arrange() and orderBy() on a DataFrame", { }) test_that("filter() on a DataFrame", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) filtered <- filter(df, 
"age > 20") expect_equal(count(filtered), 1) expect_equal(collect(filtered)$name, "Andy") @@ -1400,7 +1397,7 @@ test_that("filter() on a DataFrame", { }) test_that("join() and merge() on a DataFrame", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) mockLines2 <- c("{\"name\":\"Michael\", \"test\": \"yes\"}", "{\"name\":\"Andy\", \"test\": \"no\"}", @@ -1408,7 +1405,7 @@ test_that("join() and merge() on a DataFrame", { "{\"name\":\"Bob\", \"test\": \"yes\"}") jsonPath2 <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(mockLines2, jsonPath2) - df2 <- read.json(sqlContext, jsonPath2) + df2 <- read.json(jsonPath2) joined <- join(df, df2) expect_equal(names(joined), c("age", "name", "name", "test")) @@ -1483,7 +1480,7 @@ test_that("join() and merge() on a DataFrame", { "{\"name\":\"Bob\", \"name_y\":\"Bob\", \"test\": \"yes\"}") jsonPath3 <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(mockLines3, jsonPath3) - df3 <- read.json(sqlContext, jsonPath3) + df3 <- read.json(jsonPath3) expect_error(merge(df, df3), paste("The following column name: name_y occurs more than once in the 'DataFrame'.", "Please use different suffixes for the intersected columns.", sep = "")) @@ -1493,7 +1490,7 @@ test_that("join() and merge() on a DataFrame", { }) test_that("toJSON() returns an RDD of the correct values", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) testRDD <- toJSON(df) expect_is(testRDD, "RDD") expect_equal(getSerializedMode(testRDD), "string") @@ -1501,7 +1498,7 @@ test_that("toJSON() returns an RDD of the correct values", { }) test_that("showDF()", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) expected <- paste("+----+-------+\n", "| age| name|\n", "+----+-------+\n", @@ -1513,19 +1510,19 @@ test_that("showDF()", { }) test_that("isLocal()", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) expect_false(isLocal(df)) }) test_that("unionAll(), rbind(), except(), and intersect() on a DataFrame", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) lines <- c("{\"name\":\"Bob\", \"age\":24}", "{\"name\":\"Andy\", \"age\":30}", "{\"name\":\"James\", \"age\":35}") jsonPath2 <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(lines, jsonPath2) - df2 <- read.df(sqlContext, jsonPath2, "json") + df2 <- read.df(jsonPath2, "json") unioned <- arrange(unionAll(df, df2), df$age) expect_is(unioned, "SparkDataFrame") @@ -1557,7 +1554,7 @@ test_that("unionAll(), rbind(), except(), and intersect() on a DataFrame", { }) test_that("withColumn() and withColumnRenamed()", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) newDF <- withColumn(df, "newAge", df$age + 2) expect_equal(length(columns(newDF)), 3) expect_equal(columns(newDF)[3], "newAge") @@ -1574,7 +1571,7 @@ test_that("withColumn() and withColumnRenamed()", { }) test_that("mutate(), transform(), rename() and names()", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) newDF <- mutate(df, newAge = df$age + 2) expect_equal(length(columns(newDF)), 3) expect_equal(columns(newDF)[3], "newAge") @@ -1622,10 +1619,10 @@ test_that("mutate(), transform(), rename() and names()", { }) test_that("read/write Parquet files", { - df <- read.df(sqlContext, jsonPath, "json") + df <- read.df(jsonPath, "json") # Test write.df and read.df write.df(df, parquetPath, "parquet", mode = "overwrite") - df2 <- read.df(sqlContext, parquetPath, "parquet") + df2 <- read.df(parquetPath, "parquet") 
expect_is(df2, "SparkDataFrame") expect_equal(count(df2), 3) @@ -1634,10 +1631,10 @@ test_that("read/write Parquet files", { write.parquet(df, parquetPath2) parquetPath3 <- tempfile(pattern = "parquetPath3", fileext = ".parquet") suppressWarnings(saveAsParquetFile(df, parquetPath3)) - parquetDF <- read.parquet(sqlContext, c(parquetPath2, parquetPath3)) + parquetDF <- read.parquet(c(parquetPath2, parquetPath3)) expect_is(parquetDF, "SparkDataFrame") expect_equal(count(parquetDF), count(df) * 2) - parquetDF2 <- suppressWarnings(parquetFile(sqlContext, parquetPath2, parquetPath3)) + parquetDF2 <- suppressWarnings(parquetFile(parquetPath2, parquetPath3)) expect_is(parquetDF2, "SparkDataFrame") expect_equal(count(parquetDF2), count(df) * 2) @@ -1654,7 +1651,7 @@ test_that("read/write Parquet files", { test_that("read/write text files", { # Test write.df and read.df - df <- read.df(sqlContext, jsonPath, "text") + df <- read.df(jsonPath, "text") expect_is(df, "SparkDataFrame") expect_equal(colnames(df), c("value")) expect_equal(count(df), 3) @@ -1664,7 +1661,7 @@ test_that("read/write text files", { # Test write.text and read.text textPath2 <- tempfile(pattern = "textPath2", fileext = ".txt") write.text(df, textPath2) - df2 <- read.text(sqlContext, c(textPath, textPath2)) + df2 <- read.text(c(textPath, textPath2)) expect_is(df2, "SparkDataFrame") expect_equal(colnames(df2), c("value")) expect_equal(count(df2), count(df) * 2) @@ -1674,7 +1671,7 @@ test_that("read/write text files", { }) test_that("describe() and summarize() on a DataFrame", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) stats <- describe(df, "age") expect_equal(collect(stats)[1, "summary"], "count") expect_equal(collect(stats)[2, "age"], "24.5") @@ -1692,7 +1689,7 @@ test_that("describe() and summarize() on a DataFrame", { }) test_that("dropna() and na.omit() on a DataFrame", { - df <- read.json(sqlContext, jsonPathNa) + df <- read.json(jsonPathNa) rows <- collect(df) # drop with columns @@ -1778,7 +1775,7 @@ test_that("dropna() and na.omit() on a DataFrame", { }) test_that("fillna() on a DataFrame", { - df <- read.json(sqlContext, jsonPathNa) + df <- read.json(jsonPathNa) rows <- collect(df) # fill with value @@ -1829,7 +1826,7 @@ test_that("crosstab() on a DataFrame", { test_that("cov() and corr() on a DataFrame", { l <- lapply(c(0:9), function(x) { list(x, x * 2.0) }) - df <- createDataFrame(sqlContext, l, c("singles", "doubles")) + df <- createDataFrame(l, c("singles", "doubles")) result <- cov(df, "singles", "doubles") expect_true(abs(result - 55.0 / 3) < 1e-12) @@ -1847,7 +1844,7 @@ test_that("freqItems() on a DataFrame", { rdf <- data.frame(numbers = input, letters = as.character(input), negDoubles = input * -1.0, stringsAsFactors = F) rdf[ input %% 3 == 0, ] <- c(1, "1", -1) - df <- createDataFrame(sqlContext, rdf) + df <- createDataFrame(rdf) multiColResults <- freqItems(df, c("numbers", "letters"), support = 0.1) expect_true(1 %in% multiColResults$numbers[[1]]) expect_true("1" %in% multiColResults$letters[[1]]) @@ -1857,7 +1854,7 @@ test_that("freqItems() on a DataFrame", { l <- lapply(c(0:99), function(i) { if (i %% 2 == 0) { list(1L, -1.0) } else { list(i, i * -1.0) }}) - df <- createDataFrame(sqlContext, l, c("a", "b")) + df <- createDataFrame(l, c("a", "b")) result <- freqItems(df, c("a", "b"), 0.4) expect_identical(result[[1]], list(list(1L, 99L))) expect_identical(result[[2]], list(list(-1, -99))) @@ -1865,7 +1862,7 @@ test_that("freqItems() on a DataFrame", { test_that("sampleBy() on a 
DataFrame", { l <- lapply(c(0:99), function(i) { as.character(i %% 3) }) - df <- createDataFrame(sqlContext, l, "key") + df <- createDataFrame(l, "key") fractions <- list("0" = 0.1, "1" = 0.2) sample <- sampleBy(df, "key", fractions, 0) result <- collect(orderBy(count(groupBy(sample, "key")), "key")) @@ -1875,19 +1872,19 @@ test_that("sampleBy() on a DataFrame", { test_that("approxQuantile() on a DataFrame", { l <- lapply(c(0:99), function(i) { i }) - df <- createDataFrame(sqlContext, l, "key") + df <- createDataFrame(l, "key") quantiles <- approxQuantile(df, "key", c(0.5, 0.8), 0.0) expect_equal(quantiles[[1]], 50) expect_equal(quantiles[[2]], 80) }) test_that("SQL error message is returned from JVM", { - retError <- tryCatch(sql(sqlContext, "select * from blah"), error = function(e) e) + retError <- tryCatch(sql("select * from blah"), error = function(e) e) expect_equal(grepl("Table or view not found", retError), TRUE) expect_equal(grepl("blah", retError), TRUE) }) -irisDF <- suppressWarnings(createDataFrame(sqlContext, iris)) +irisDF <- suppressWarnings(createDataFrame(iris)) test_that("Method as.data.frame as a synonym for collect()", { expect_equal(as.data.frame(irisDF), collect(irisDF)) @@ -1899,7 +1896,7 @@ test_that("Method as.data.frame as a synonym for collect()", { }) test_that("attach() on a DataFrame", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) expect_error(age) attach(df) expect_is(age, "SparkDataFrame") @@ -1919,7 +1916,7 @@ test_that("attach() on a DataFrame", { }) test_that("with() on a DataFrame", { - df <- suppressWarnings(createDataFrame(sqlContext, iris)) + df <- suppressWarnings(createDataFrame(iris)) expect_error(Sepal_Length) sum1 <- with(df, list(summary(Sepal_Length), summary(Sepal_Width))) expect_equal(collect(sum1[[1]])[1, "Sepal_Length"], "150") @@ -1939,15 +1936,15 @@ test_that("Method coltypes() to get and set R's data types of a DataFrame", { structField("c4", "timestamp")) # Test primitive types - DF <- createDataFrame(sqlContext, data, schema) + DF <- createDataFrame(data, schema) expect_equal(coltypes(DF), c("integer", "logical", "POSIXct")) # Test complex types - x <- createDataFrame(sqlContext, list(list(as.environment( + x <- createDataFrame(list(list(as.environment( list("a" = "b", "c" = "d", "e" = "f"))))) expect_equal(coltypes(x), "map") - df <- selectExpr(read.json(sqlContext, jsonPath), "name", "(age * 1.21) as age") + df <- selectExpr(read.json(jsonPath), "name", "(age * 1.21) as age") expect_equal(dtypes(df), list(c("name", "string"), c("age", "decimal(24,2)"))) df1 <- select(df, cast(df$age, "integer")) @@ -1971,7 +1968,7 @@ test_that("Method str()", { iris2 <- iris colnames(iris2) <- c("Sepal_Length", "Sepal_Width", "Petal_Length", "Petal_Width", "Species") iris2$col <- TRUE - irisDF2 <- createDataFrame(sqlContext, iris2) + irisDF2 <- createDataFrame(iris2) out <- capture.output(str(irisDF2)) expect_equal(length(out), 7) @@ -1989,7 +1986,7 @@ test_that("Method str()", { # number of returned rows x <- runif(200, 1, 10) df <- data.frame(t(as.matrix(data.frame(x, x, x, x, x, x, x, x, x)))) - DF <- createDataFrame(sqlContext, df) + DF <- createDataFrame(df) out <- capture.output(str(DF)) expect_equal(length(out), 103) @@ -2039,13 +2036,12 @@ test_that("Histogram", { histogram(irisDF, "Sepal_Width", 12)$counts), T) # Test when there are zero counts - df <- as.DataFrame(sqlContext, data.frame(x = c(1, 2, 3, 4, 100))) + df <- as.DataFrame(data.frame(x = c(1, 2, 3, 4, 100))) expect_equal(histogram(df, "x")$counts, c(4, 0, 0, 
0, 0, 0, 0, 0, 0, 1)) }) test_that("dapply() and dapplyCollect() on a DataFrame", { - df <- createDataFrame ( - sqlContext, + df <- createDataFrame( list(list(1L, 1, "1"), list(2L, 2, "2"), list(3L, 3, "3")), c("a", "b", "c")) ldf <- collect(df) @@ -2102,8 +2098,7 @@ test_that("dapply() and dapplyCollect() on a DataFrame", { }) test_that("repartition by columns on DataFrame", { - df <- createDataFrame ( - sqlContext, + df <- createDataFrame( list(list(1L, 1, "1", 0.1), list(1L, 2, "2", 0.2), list(3L, 3, "3", 0.3)), c("a", "b", "c", "d")) @@ -2173,6 +2168,30 @@ test_that("Window functions on a DataFrame", { expect_equal(result, expected) }) +test_that("createDataFrame sqlContext parameter backward compatibility", { + a <- 1:3 + b <- c("a", "b", "c") + ldf <- data.frame(a, b) + df <- createDataFrame(sqlContext, ldf) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + expect_equal(count(df), 3) + ldf2 <- collect(df) + expect_equal(ldf$a, ldf2$a) + + df2 <- createDataFrame(sqlContext, iris) + expect_equal(count(df2), 150) + expect_equal(ncol(df2), 5) + + df3 <- read.df(sqlContext, jsonPath, "json") + expect_is(df3, "SparkDataFrame") + expect_equal(count(df3), 3) + + before <- createDataFrame(sqlContext, iris) + after <- createDataFrame(iris) + expect_equal(collect(before), collect(after)) +}) + unlink(parquetPath) unlink(jsonPath) unlink(jsonPathNa) From 4884e56efe19233ce35574d2aeeebb7d3213e20e Mon Sep 17 00:00:00 2001 From: felixcheung Date: Mon, 23 May 2016 16:55:25 -0700 Subject: [PATCH 05/15] suppress warnings for tests --- R/pkg/inst/tests/testthat/test_sparkSQL.R | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 6dacf670b300..c958d815a74d 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -2172,23 +2172,23 @@ test_that("createDataFrame sqlContext parameter backward compatibility", { a <- 1:3 b <- c("a", "b", "c") ldf <- data.frame(a, b) - df <- createDataFrame(sqlContext, ldf) + df <- suppressWarnings(createDataFrame(sqlContext, ldf)) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) expect_equal(count(df), 3) ldf2 <- collect(df) expect_equal(ldf$a, ldf2$a) - df2 <- createDataFrame(sqlContext, iris) + df2 <- suppressWarnings(createDataFrame(sqlContext, iris)) expect_equal(count(df2), 150) expect_equal(ncol(df2), 5) - df3 <- read.df(sqlContext, jsonPath, "json") + df3 <- suppressWarnings(read.df(sqlContext, jsonPath, "json")) expect_is(df3, "SparkDataFrame") expect_equal(count(df3), 3) - before <- createDataFrame(sqlContext, iris) - after <- createDataFrame(iris) + before <- suppressWarnings(createDataFrame(sqlContext, iris)) + after <- suppressWarnings(createDataFrame(iris)) expect_equal(collect(before), collect(after)) }) From d9d72cf6c9fcb06c12102679b1ab5e5c1e0965b0 Mon Sep 17 00:00:00 2001 From: felixcheung Date: Mon, 23 May 2016 17:02:40 -0700 Subject: [PATCH 06/15] fix roxygen2 doc gen --- R/pkg/R/SQLContext.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 97d4e98f8ee6..e13f4daf7883 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -652,7 +652,7 @@ read.df.default <- function(path = NULL, source = NULL, schema = NULL, ...) 
{ dataFrame(sdf) } -#' @aliases read.df +#' @rdname read.df #' @name loadDF loadDF <- function(x, ...) { dispatchFunc("loadDF(path = NULL, source = NULL, schema = NULL, ...)", x, ...) From 3a2e0c7919b9fdbd5558cda474368c25208856b0 Mon Sep 17 00:00:00 2001 From: felixcheung Date: Mon, 23 May 2016 22:23:50 -0700 Subject: [PATCH 07/15] fix lint-r --- R/pkg/R/SQLContext.R | 2 -- 1 file changed, 2 deletions(-) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index e13f4daf7883..b8fefcbac31e 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -43,8 +43,6 @@ dispatchFunc <- function(newFuncSig, x, ...) { funcName <- as.character(sys.call(sys.parent())[[1]]) f <- get(paste0(funcName, ".default")) # Strip sqlContext from list of parameters and then pass the rest along. - # In the following, if '&' is used instead of '&&', it warns about - # "the condition has length > 1 and only the first element will be used" if (missing(x) && length(list(...)) == 0) { f() } else if (class(x) == "jobj" && From a9479dd3ea1f8db84ec7dd26989a0476a39419ec Mon Sep 17 00:00:00 2001 From: felixcheung Date: Mon, 23 May 2016 23:50:56 -0700 Subject: [PATCH 08/15] fix test with hivecontext --- R/pkg/R/SQLContext.R | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index b8fefcbac31e..c7320a3ec227 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -46,7 +46,8 @@ dispatchFunc <- function(newFuncSig, x, ...) { if (missing(x) && length(list(...)) == 0) { f() } else if (class(x) == "jobj" && - getClassName.jobj(x) == "org.apache.spark.sql.SQLContext") { + (getClassName.jobj(x) == "org.apache.spark.sql.SQLContext" || + getClassName.jobj(x) == "org.apache.spark.sql.hive.HiveContext")) { .Deprecated(newFuncSig, old = paste0(funcName, "(sqlContext...)")) f(...) } else { From c4fd5cdecd29fa199c66574729c1aadd127d97bf Mon Sep 17 00:00:00 2001 From: felixcheung Date: Wed, 25 May 2016 00:25:52 -0700 Subject: [PATCH 09/15] rearrange code and add tag to fix doc, fix hive context tests, fix warn --- R/pkg/R/SQLContext.R | 246 ++++++++++++---------- R/pkg/inst/tests/testthat/test_sparkSQL.R | 18 +- 2 files changed, 143 insertions(+), 121 deletions(-) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index c7320a3ec227..36177f5f4b1c 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -43,11 +43,13 @@ dispatchFunc <- function(newFuncSig, x, ...) { funcName <- as.character(sys.call(sys.parent())[[1]]) f <- get(paste0(funcName, ".default")) # Strip sqlContext from list of parameters and then pass the rest along. + contextNames <- c("org.apache.spark.sql.SQLContext", + "org.apache.spark.sql.hive.HiveContext", + "org.apache.spark.sql.hive.test.TestHiveContext") if (missing(x) && length(list(...)) == 0) { f() } else if (class(x) == "jobj" && - (getClassName.jobj(x) == "org.apache.spark.sql.SQLContext" || - getClassName.jobj(x) == "org.apache.spark.sql.hive.HiveContext")) { + any(grepl(paste(contextNames, collapse = "|"), getClassName.jobj(x)))) { .Deprecated(newFuncSig, old = paste0(funcName, "(sqlContext...)")) f(...) } else { @@ -103,7 +105,6 @@ infer_type <- function(x) { #' #' Converts R data.frame or list into SparkDataFrame. 
#' -#' @param sqlContext A SQLContext #' @param data An RDD or list or data.frame #' @param schema a list of column names or named list (StructType), optional #' @return a SparkDataFrame @@ -113,16 +114,14 @@ infer_type <- function(x) { #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) -#' df1 <- as.DataFrame(sqlContext, iris) -#' df2 <- as.DataFrame(sqlContext, list(3,4,5,6)) -#' df3 <- createDataFrame(sqlContext, iris) +#' df1 <- as.DataFrame(iris) +#' df2 <- as.DataFrame(list(3,4,5,6)) +#' df3 <- createDataFrame(iris) #' } +#' @name createDataFrame +#' @method createDataFrame default # TODO(davies): support sampling and infer type from NA -createDataFrame <- function(x, ...) { - dispatchFunc("createDataFrame(data, schema = NULL, samplingRatio = 1.0)", x, ...) -} - createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) { sqlContext <- getSqlContext() if (is.data.frame(data)) { @@ -198,17 +197,23 @@ createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) { dataFrame(sdf) } +createDataFrame <- function(x, ...) { + dispatchFunc("createDataFrame(data, schema = NULL, samplingRatio = 1.0)", x, ...) +} + #' @rdname createDataFrame #' @aliases createDataFrame #' @export -as.DataFrame <- function(x, ...) { - dispatchFunc("as.DataFrame(data, schema = NULL, samplingRatio = 1.0)", x, ...) -} +#' @method as.DataFrame default as.DataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) { createDataFrame(data, schema, samplingRatio) } +as.DataFrame <- function(x, ...) { + dispatchFunc("as.DataFrame(data, schema = NULL, samplingRatio = 1.0)", x, ...) +} + #' toDF #' #' Converts an RDD to a SparkDataFrame by infer the types. @@ -236,24 +241,20 @@ setMethod("toDF", signature(x = "RDD"), #' Loads a JSON file (one object per line), returning the result as a SparkDataFrame #' It goes through the entire dataset once to determine the schema. #' -#' @param sqlContext SQLContext to use #' @param path Path of file to read. A vector of multiple paths is allowed. #' @return SparkDataFrame #' @rdname read.json -#' @name read.json #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(path) +#' df <- jsonFile(path) #' } - -read.json <- function(x, ...) { - dispatchFunc("read.json(path)", x, ...) -} +#' @name read.json +#' @method read.json default read.json.default <- function(path) { sqlContext <- getSqlContext() @@ -264,18 +265,23 @@ read.json.default <- function(path) { dataFrame(sdf) } +read.json <- function(x, ...) { + dispatchFunc("read.json(path)", x, ...) +} + #' @rdname read.json #' @name jsonFile #' @export -jsonFile <- function(x, ...) { - dispatchFunc("jsonFile(path)", x, ...) -} +#' @method jsonFile default jsonFile.default <- function(path) { .Deprecated("read.json") read.json(path) } +jsonFile <- function(x, ...) { + dispatchFunc("jsonFile(path)", x, ...) +} #' JSON RDD #' @@ -314,16 +320,12 @@ jsonRDD <- function(sqlContext, rdd, schema = NULL, samplingRatio = 1.0) { #' #' Loads a Parquet file, returning the result as a SparkDataFrame. #' -#' @param sqlContext SQLContext to use #' @param path Path of file to read. A vector of multiple paths is allowed. #' @return SparkDataFrame #' @rdname read.parquet -#' @name read.parquet #' @export - -read.parquet <- function(x, ...) { - dispatchFunc("read.parquet(...)", x, ...) 
-} +#' @name read.parquet +#' @method read.parquet default read.parquet.default <- function(path) { sqlContext <- getSqlContext() @@ -334,19 +336,24 @@ read.parquet.default <- function(path) { dataFrame(sdf) } +read.parquet <- function(x, ...) { + dispatchFunc("read.parquet(...)", x, ...) +} + #' @rdname read.parquet #' @name parquetFile #' @export -# TODO: Implement saveasParquetFile and write examples for both -parquetFile <- function(x, ...) { - dispatchFunc("parquetFile(...)", x, ...) -} +#' @method parquetFile default parquetFile.default <- function(...) { .Deprecated("read.parquet") read.parquet(unlist(list(...))) } +parquetFile <- function(x, ...) { + dispatchFunc("parquetFile(...)", x, ...) +} + #' Create a SparkDataFrame from a text file. #' #' Loads a text file and returns a SparkDataFrame with a single string column named "value". @@ -354,23 +361,19 @@ parquetFile.default <- function(...) { #' ignored in the resulting DataFrame. #' Each line in the text file is a new row in the resulting SparkDataFrame. #' -#' @param sqlContext SQLContext to use #' @param path Path of file to read. A vector of multiple paths is allowed. #' @return SparkDataFrame #' @rdname read.text -#' @name read.text #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.txt" -#' df <- read.text(sqlContext, path) +#' df <- read.text(path) #' } - -read.text <- function(x, ...) { - dispatchFunc("read.text(path)", x, ...) -} +#' @name read.text +#' @method read.text default read.text.default <- function(path) { sqlContext <- getSqlContext() @@ -381,27 +384,29 @@ read.text.default <- function(path) { dataFrame(sdf) } +read.text <- function(x, ...) { + dispatchFunc("read.text(path)", x, ...) +} + #' SQL Query #' #' Executes a SQL query using Spark, returning the result as a SparkDataFrame. #' -#' @param sqlContext SQLContext to use #' @param sqlQuery A character vector containing the SQL query #' @return SparkDataFrame +#' @rdname sql #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) +#' df <- read.json(path) #' registerTempTable(df, "table") -#' new_df <- sql(sqlContext, "SELECT * FROM table") +#' new_df <- sql("SELECT * FROM table") #' } - -sql <- function(x, ...) { - dispatchFunc("sql(sqlQuery)", x, ...) -} +#' @name sql +#' @method sql default sql.default <- function(sqlQuery) { sqlContext <- getSqlContext() @@ -409,12 +414,15 @@ sql.default <- function(sqlQuery) { dataFrame(sdf) } +sql <- function(x, ...) { + dispatchFunc("sql(sqlQuery)", x, ...) +} + #' Create a SparkDataFrame from a SparkSQL Table #' #' Returns the specified Table as a SparkDataFrame. The Table must have already been registered #' in the SQLContext. #' -#' @param sqlContext SQLContext to use #' @param tableName The SparkSQL Table to convert to a SparkDataFrame. #' @return SparkDataFrame #' @rdname tableToDF @@ -425,16 +433,13 @@ sql.default <- function(sqlQuery) { #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) +#' df <- read.json(path) #' registerTempTable(df, "table") -#' new_df <- tableToDF(sqlContext, "table") +#' new_df <- tableToDF("table") #' } +#' @note since 2.0.0 -tableToDF <- function(x, ...) { - dispatchFunc("tableToDF(tableName)", x, ...) 
-} - -tableToDF.default <- function(tableName) { +tableToDF <- function(tableName) { sqlContext <- getSqlContext() sdf <- callJMethod(sqlContext, "table", tableName) dataFrame(sdf) @@ -444,20 +449,18 @@ tableToDF.default <- function(tableName) { #' #' Returns a SparkDataFrame containing names of tables in the given database. #' -#' @param sqlContext SQLContext to use #' @param databaseName name of the database #' @return a SparkDataFrame +#' @rdname tables #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) -#' tables(sqlContext, "hive") +#' tables("hive") #' } - -tables <- function(x, ...) { - dispatchFunc("tables(databaseName = NULL)", x, ...) -} +#' @name tables +#' @method tables default tables.default <- function(databaseName = NULL) { sqlContext <- getSqlContext() @@ -469,25 +472,26 @@ tables.default <- function(databaseName = NULL) { dataFrame(jdf) } +tables <- function(x, ...) { + dispatchFunc("tables(databaseName = NULL)", x, ...) +} #' Table Names #' #' Returns the names of tables in the given database as an array. #' -#' @param sqlContext SQLContext to use #' @param databaseName name of the database #' @return a list of table names +#' @rdname tableNames #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) -#' tableNames(sqlContext, "hive") +#' tableNames("hive") #' } - -tableNames <- function(x, ...) { - dispatchFunc("tableNames(databaseName = NULL)", x, ...) -} +#' @name tableNames +#' @method tableNames default tableNames.default <- function(databaseName = NULL) { sqlContext <- getSqlContext() @@ -498,99 +502,108 @@ tableNames.default <- function(databaseName = NULL) { } } +tableNames <- function(x, ...) { + dispatchFunc("tableNames(databaseName = NULL)", x, ...) +} #' Cache Table #' #' Caches the specified table in-memory. #' -#' @param sqlContext SQLContext to use #' @param tableName The name of the table being cached #' @return SparkDataFrame +#' @rdname cacheTable #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) +#' df <- read.json(path) #' registerTempTable(df, "table") -#' cacheTable(sqlContext, "table") +#' cacheTable("table") #' } - -cacheTable <- function(x, ...) { - dispatchFunc("cacheTable(tableName)", x, ...) -} +#' @name cacheTable +#' @method cacheTable default cacheTable.default <- function(tableName) { sqlContext <- getSqlContext() callJMethod(sqlContext, "cacheTable", tableName) } +cacheTable <- function(x, ...) { + dispatchFunc("cacheTable(tableName)", x, ...) +} + #' Uncache Table #' #' Removes the specified table from the in-memory cache. #' -#' @param sqlContext SQLContext to use #' @param tableName The name of the table being uncached #' @return SparkDataFrame +#' @rdname uncacheTable #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) +#' df <- read.json(path) #' registerTempTable(df, "table") -#' uncacheTable(sqlContext, "table") +#' uncacheTable("table") #' } - -uncacheTable <- function(x, ...) { - dispatchFunc("uncacheTable(tableName)", x, ...) -} +#' @name uncacheTable +#' @method uncacheTable default uncacheTable.default <- function(tableName) { sqlContext <- getSqlContext() callJMethod(sqlContext, "uncacheTable", tableName) } +uncacheTable <- function(x, ...) { + dispatchFunc("uncacheTable(tableName)", x, ...) 
+} + #' Clear Cache #' #' Removes all cached tables from the in-memory cache. #' -#' @param sqlContext SQLContext to use +#' @rdname clearCache +#' @export #' @examples #' \dontrun{ -#' clearCache(sqlContext) +#' clearCache() #' } - -clearCache <- function() { - dispatchFunc("clearCache()") -} +#' @name clearCache +#' @method clearCache default clearCache.default <- function() { sqlContext <- getSqlContext() callJMethod(sqlContext, "clearCache") } +clearCache <- function() { + dispatchFunc("clearCache()") +} + #' Drop Temporary Table #' #' Drops the temporary table with the given table name in the catalog. #' If the table has been cached/persisted before, it's also unpersisted. #' -#' @param sqlContext SQLContext to use #' @param tableName The name of the SparkSQL table to be dropped. +#' @rdname dropTempTable +#' @export #' @examples #' \dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) -#' df <- read.df(sqlContext, path, "parquet") +#' df <- read.df(path, "parquet") #' registerTempTable(df, "table") -#' dropTempTable(sqlContext, "table") +#' dropTempTable("table") #' } - -dropTempTable <- function(x, ...) { - dispatchFunc("dropTempTable(tableName)", x, ...) -} +#' @name dropTempTable +#' @method dropTempTable default dropTempTable.default <- function(tableName) { sqlContext <- getSqlContext() @@ -600,6 +613,10 @@ dropTempTable.default <- function(tableName) { callJMethod(sqlContext, "dropTempTable", tableName) } +dropTempTable <- function(x, ...) { + dispatchFunc("dropTempTable(tableName)", x, ...) +} + #' Load a SparkDataFrame #' #' Returns the dataset in a data source as a SparkDataFrame @@ -608,7 +625,6 @@ dropTempTable.default <- function(tableName) { #' If `source` is not specified, the default data source configured by #' "spark.sql.sources.default" will be used. #' -#' @param sqlContext SQLContext to use #' @param path The path of files to load #' @param source The name of external data source #' @param schema The data schema defined in structType @@ -620,16 +636,14 @@ dropTempTable.default <- function(tableName) { #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) -#' df1 <- read.df(sqlContext, "path/to/file.json", source = "json") +#' df1 <- read.df("path/to/file.json", source = "json") #' schema <- structType(structField("name", "string"), #' structField("info", "map")) -#' df2 <- read.df(sqlContext, mapTypeJsonPath, "json", schema) -#' df3 <- loadDF(sqlContext, "data/test_table", "parquet", mergeSchema = "true") +#' df2 <- read.df(mapTypeJsonPath, "json", schema) +#' df3 <- loadDF("data/test_table", "parquet", mergeSchema = "true") #' } - -read.df <- function(x, ...) { - dispatchFunc("read.df(path = NULL, source = NULL, schema = NULL, ...)", x, ...) -} +#' @name read.df +#' @method read.df default read.df.default <- function(path = NULL, source = NULL, schema = NULL, ...) { sqlContext <- getSqlContext() @@ -651,16 +665,22 @@ read.df.default <- function(path = NULL, source = NULL, schema = NULL, ...) { dataFrame(sdf) } +read.df <- function(x, ...) { + dispatchFunc("read.df(path = NULL, source = NULL, schema = NULL, ...)", x, ...) +} + #' @rdname read.df #' @name loadDF -loadDF <- function(x, ...) { - dispatchFunc("loadDF(path = NULL, source = NULL, schema = NULL, ...)", x, ...) -} +#' @method loadDF default loadDF.default <- function(path = NULL, source = NULL, schema = NULL, ...) { read.df(path, source, schema, ...) } +loadDF <- function(x, ...) { + dispatchFunc("loadDF(path = NULL, source = NULL, schema = NULL, ...)", x, ...) 
+} + #' Create an external table #' #' Creates an external table based on the dataset in a data source, @@ -670,22 +690,20 @@ loadDF.default <- function(path = NULL, source = NULL, schema = NULL, ...) { #' If `source` is not specified, the default data source configured by #' "spark.sql.sources.default" will be used. #' -#' @param sqlContext SQLContext to use #' @param tableName A name of the table #' @param path The path of files to load #' @param source the name of external data source #' @return SparkDataFrame +#' @rdname createExternalTable #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) -#' df <- sparkRSQL.createExternalTable(sqlContext, "myjson", path="path/to/json", source="json") +#' df <- sparkRSQL.createExternalTable("myjson", path="path/to/json", source="json") #' } - -createExternalTable <- function(x, ...) { - dispatchFunc("createExternalTable(tableName, path = NULL, source = NULL, ...)", x, ...) -} +#' @name createExternalTable +#' @method createExternalTable default createExternalTable.default <- function(tableName, path = NULL, source = NULL, ...) { sqlContext <- getSqlContext() @@ -697,6 +715,10 @@ createExternalTable.default <- function(tableName, path = NULL, source = NULL, . dataFrame(sdf) } +createExternalTable <- function(x, ...) { + dispatchFunc("createExternalTable(tableName, path = NULL, source = NULL, ...)", x, ...) +} + #' Create a SparkDataFrame representing the database table accessible via JDBC URL #' #' Additional JDBC database connection properties can be set (...) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index c958d815a74d..767145be9f12 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -169,8 +169,8 @@ test_that("create DataFrame from RDD", { error = function(err) { skip("Hive is not build with SparkSQL, skipped") }) - sql(hiveCtx, "CREATE TABLE people (name string, age double, height float)") - df <- read.df(hiveCtx, jsonPathNa, "json", schema) + suppressWarnings(sql(hiveCtx, "CREATE TABLE people (name string, age double, height float)")) + df <- suppressWarnings(read.df(hiveCtx, jsonPathNa, "json", schema)) invisible(insertInto(df, "people")) expect_equal(collect(sql(hiveCtx, "SELECT age from people WHERE name = 'Bob'"))$age, c(16)) @@ -959,30 +959,30 @@ test_that("test HiveContext", { error = function(err) { skip("Hive is not build with SparkSQL, skipped") }) - df <- createExternalTable(hiveCtx, "json", jsonPath, "json") + df <- suppressWarnings(createExternalTable(hiveCtx, "json", jsonPath, "json")) expect_is(df, "SparkDataFrame") expect_equal(count(df), 3) - df2 <- sql(hiveCtx, "select * from json") + df2 <- suppressWarnings(sql(hiveCtx, "select * from json")) expect_is(df2, "SparkDataFrame") expect_equal(count(df2), 3) jsonPath2 <- tempfile(pattern = "sparkr-test", fileext = ".tmp") invisible(saveAsTable(df, "json2", "json", "append", path = jsonPath2)) - df3 <- sql(hiveCtx, "select * from json2") + df3 <- suppressWarnings(sql(hiveCtx, "select * from json2")) expect_is(df3, "SparkDataFrame") expect_equal(count(df3), 3) unlink(jsonPath2) hivetestDataPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") invisible(saveAsTable(df, "hivetestbl", path = hivetestDataPath)) - df4 <- sql(hiveCtx, "select * from hivetestbl") + df4 <- suppressWarnings(sql(hiveCtx, "select * from hivetestbl")) expect_is(df4, "SparkDataFrame") expect_equal(count(df4), 3) unlink(hivetestDataPath) parquetDataPath <- 
tempfile(pattern = "sparkr-test", fileext = ".tmp") invisible(saveAsTable(df, "parquetest", "parquet", mode = "overwrite", path = parquetDataPath)) - df5 <- sql(hiveCtx, "select * from parquetest") + df5 <- suppressWarnings(sql(hiveCtx, "select * from parquetest")) expect_is(df5, "SparkDataFrame") expect_equal(count(df5), 3) unlink(parquetDataPath) @@ -2141,9 +2141,9 @@ test_that("Window functions on a DataFrame", { skip("Hive is not build with SparkSQL, skipped") }) - df <- createDataFrame(hiveCtx, + df <- suppressWarnings(createDataFrame(hiveCtx, list(list(1L, "1"), list(2L, "2"), list(1L, "1"), list(2L, "2")), - schema = c("key", "value")) + schema = c("key", "value"))) ws <- orderBy(window.partitionBy("key"), "value") result <- collect(select(df, over(lead("key", 1), ws), over(lead("value", 1), ws))) names(result) <- c("key", "value") From 153d5e7e848bccb20f34c334efefcd9ee66957a0 Mon Sep 17 00:00:00 2001 From: felixcheung Date: Wed, 25 May 2016 01:08:39 -0700 Subject: [PATCH 10/15] ok one more time to fix test --- R/pkg/inst/tests/testthat/test_sparkSQL.R | 26 ++++++++++++----------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 767145be9f12..868f73a4ed5c 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -169,12 +169,13 @@ test_that("create DataFrame from RDD", { error = function(err) { skip("Hive is not build with SparkSQL, skipped") }) - suppressWarnings(sql(hiveCtx, "CREATE TABLE people (name string, age double, height float)")) - df <- suppressWarnings(read.df(hiveCtx, jsonPathNa, "json", schema)) + assign(".sparkRHivesc", hiveCtx, envir = .sparkREnv) + sql("CREATE TABLE people (name string, age double, height float)") + df <- read.df(jsonPathNa, "json", schema) invisible(insertInto(df, "people")) - expect_equal(collect(sql(hiveCtx, "SELECT age from people WHERE name = 'Bob'"))$age, + expect_equal(collect(sql("SELECT age from people WHERE name = 'Bob'"))$age, c(16)) - expect_equal(collect(sql(hiveCtx, "SELECT height from people WHERE name ='Bob'"))$height, + expect_equal(collect(sql("SELECT height from people WHERE name ='Bob'"))$height, c(176.5)) }) @@ -959,30 +960,31 @@ test_that("test HiveContext", { error = function(err) { skip("Hive is not build with SparkSQL, skipped") }) - df <- suppressWarnings(createExternalTable(hiveCtx, "json", jsonPath, "json")) + assign(".sparkRHivesc", hiveCtx, envir = .sparkREnv) + df <- createExternalTable("json", jsonPath, "json") expect_is(df, "SparkDataFrame") expect_equal(count(df), 3) - df2 <- suppressWarnings(sql(hiveCtx, "select * from json")) + df2 <- sql("select * from json") expect_is(df2, "SparkDataFrame") expect_equal(count(df2), 3) jsonPath2 <- tempfile(pattern = "sparkr-test", fileext = ".tmp") invisible(saveAsTable(df, "json2", "json", "append", path = jsonPath2)) - df3 <- suppressWarnings(sql(hiveCtx, "select * from json2")) + df3 <- sql("select * from json2") expect_is(df3, "SparkDataFrame") expect_equal(count(df3), 3) unlink(jsonPath2) hivetestDataPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") invisible(saveAsTable(df, "hivetestbl", path = hivetestDataPath)) - df4 <- suppressWarnings(sql(hiveCtx, "select * from hivetestbl")) + df4 <- sql("select * from hivetestbl") expect_is(df4, "SparkDataFrame") expect_equal(count(df4), 3) unlink(hivetestDataPath) parquetDataPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") invisible(saveAsTable(df, 
"parquetest", "parquet", mode = "overwrite", path = parquetDataPath)) - df5 <- suppressWarnings(sql(hiveCtx, "select * from parquetest")) + df5 <- sql("select * from parquetest") expect_is(df5, "SparkDataFrame") expect_equal(count(df5), 3) unlink(parquetDataPath) @@ -2141,9 +2143,9 @@ test_that("Window functions on a DataFrame", { skip("Hive is not build with SparkSQL, skipped") }) - df <- suppressWarnings(createDataFrame(hiveCtx, - list(list(1L, "1"), list(2L, "2"), list(1L, "1"), list(2L, "2")), - schema = c("key", "value"))) + assign(".sparkRHivesc", hiveCtx, envir = .sparkREnv) + df <- createDataFrame(list(list(1L, "1"), list(2L, "2"), list(1L, "1"), list(2L, "2")), + schema = c("key", "value")) ws <- orderBy(window.partitionBy("key"), "value") result <- collect(select(df, over(lead("key", 1), ws), over(lead("value", 1), ws))) names(result) <- c("key", "value") From f53b148e91c58abfe40dfc17f0374ff511d0e5f1 Mon Sep 17 00:00:00 2001 From: felixcheung Date: Wed, 25 May 2016 02:19:59 -0700 Subject: [PATCH 11/15] fix hive context test failure more --- R/pkg/R/SQLContext.R | 1 + R/pkg/inst/tests/test_mllib.R | 69 ----------------------- R/pkg/inst/tests/testthat/test_sparkSQL.R | 3 + 3 files changed, 4 insertions(+), 69 deletions(-) delete mode 100644 R/pkg/inst/tests/test_mllib.R diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 36177f5f4b1c..19d2b1e84fa3 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -752,6 +752,7 @@ createExternalTable <- function(x, ...) { #' df2 <- read.jdbc(jdbcUrl, "table2", partitionColumn = "index", lowerBound = 0, #' upperBound = 10000, user = "username", password = "password") #' } +#' @note since 2.0.0 read.jdbc <- function(url, tableName, partitionColumn = NULL, lowerBound = NULL, upperBound = NULL, diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R deleted file mode 100644 index 7cc7a4227e2d..000000000000 --- a/R/pkg/inst/tests/test_mllib.R +++ /dev/null @@ -1,69 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -library(testthat) - -context("MLlib functions") - -# Tests for MLlib functions in SparkR - -sc <- sparkR.init() - -sqlContext <- sparkRSQL.init(sc) - -test_that("glm and predict", { - training <- createDataFrame(iris) - test <- select(training, "Sepal_Length") - model <- glm(Sepal_Width ~ Sepal_Length, training, family = "gaussian") - prediction <- predict(model, test) - expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") -}) - -test_that("predictions match with native glm", { - training <- createDataFrame(iris) - model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) - vals <- collect(select(predict(model, training), "prediction")) - rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) - expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) -}) - -test_that("dot minus and intercept vs native glm", { - training <- createDataFrame(iris) - model <- glm(Sepal_Width ~ . - Species + 0, data = training) - vals <- collect(select(predict(model, training), "prediction")) - rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris) - expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) -}) - -test_that("feature interaction vs native glm", { - training <- createDataFrame(iris) - model <- glm(Sepal_Width ~ Species:Sepal_Length, data = training) - vals <- collect(select(predict(model, training), "prediction")) - rVals <- predict(glm(Sepal.Width ~ Species:Sepal.Length, data = iris), iris) - expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) -}) - -test_that("summary coefficients match with native glm", { - training <- createDataFrame(iris) - stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training, solver = "l-bfgs")) - coefs <- as.vector(stats$coefficients) - rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris))) - expect_true(all(abs(rCoefs - coefs) < 1e-6)) - expect_true(all( - as.character(stats$features) == - c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica"))) -}) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 868f73a4ed5c..2a7a3a917532 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -177,6 +177,7 @@ test_that("create DataFrame from RDD", { c(16)) expect_equal(collect(sql("SELECT height from people WHERE name ='Bob'"))$height, c(176.5)) + remove(".sparkRHivesc", envir = .sparkREnv) }) test_that("convert NAs to null type in DataFrames", { @@ -988,6 +989,7 @@ test_that("test HiveContext", { expect_is(df5, "SparkDataFrame") expect_equal(count(df5), 3) unlink(parquetDataPath) + remove(".sparkRHivesc", envir = .sparkREnv) }) test_that("column operators", { @@ -2168,6 +2170,7 @@ test_that("Window functions on a DataFrame", { result <- collect(select(df, over(lead("key", 1), ws), over(lead("value", 1), ws))) names(result) <- c("key", "value") expect_equal(result, expected) + remove(".sparkRHivesc", envir = .sparkREnv) }) test_that("createDataFrame sqlContext parameter backward compatibility", { From 98e7ab978dd1a7aba06ee47eb012c496f00cbe3b Mon Sep 17 00:00:00 2001 From: felixcheung Date: Wed, 25 May 2016 20:16:49 -0700 Subject: [PATCH 12/15] review feedback --- R/pkg/R/SQLContext.R | 10 ++++- R/pkg/inst/tests/testthat/test_sparkSQL.R | 45 +++++++++++------------ 2 files changed, 30 insertions(+), 25 deletions(-) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 19d2b1e84fa3..584bbbf0e4c2 100644 --- 
a/R/pkg/R/SQLContext.R
+++ b/R/pkg/R/SQLContext.R
@@ -38,7 +38,15 @@ getInternalType <- function(x) {
 }
 
 #' Temporary function to reroute old S3 Method call to new
-#' We need to check the class of x to ensure it is SQLContext before dispatching
+#' This function is specifically implemented to remove SQLContext from the parameter list.
+#' It determines the target to route the call by checking the parent of this callsite (say 'func').
+#' The target should be called 'func.default'.
+#' We need to check the class of x to ensure it is SQLContext/HiveContext before dispatching.
+#' @param newFuncSig name of the function the user should call instead in the deprecation message
+#' @param x the first parameter of the original call
+#' @param ... the rest of the parameters to pass along
+#' @return whatever the target returns
+#' @noRd
 dispatchFunc <- function(newFuncSig, x, ...) {
   funcName <- as.character(sys.call(sys.parent())[[1]])
   f <- get(paste0(funcName, ".default"))
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 2a7a3a917532..2a7dacba186e 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -32,6 +32,21 @@ markUtf8 <- function(s) {
   s
 }
 
+setHiveContext <- function() {
+  hiveCtx <- tryCatch({
+    newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc)
+  },
+  error = function(err) {
+    skip("Hive is not build with SparkSQL, skipped")
+  })
+  assign(".sparkRHivesc", hiveCtx, envir = .sparkREnv)
+  hiveCtx
+}
+
+unsetHiveContext <- function() {
+  remove(".sparkRHivesc", envir = .sparkREnv)
+}
+
 # Tests for SparkSQL functions in SparkR
 
 sc <- sparkR.init()
@@ -163,13 +178,7 @@ test_that("create DataFrame from RDD", {
                list(name = "John", age = 19L, height = 176.5))
 
   ssc <- callJMethod(sc, "sc")
-  hiveCtx <- tryCatch({
-    newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc)
-  },
-  error = function(err) {
-    skip("Hive is not build with SparkSQL, skipped")
-  })
-  assign(".sparkRHivesc", hiveCtx, envir = .sparkREnv)
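# Editorial sketch, not part of the original patch, illustrating the dispatchFunc behaviour
# documented above: with an initialized sqlContext, the old and the new call forms both end
# up in createDataFrame.default; the old form additionally emits a deprecation warning.
before <- suppressWarnings(createDataFrame(sqlContext, iris))
after <- createDataFrame(iris)
expect_equal(collect(before), collect(after))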
+ setHiveContext() df <- createDataFrame(list(list(1L, "1"), list(2L, "2"), list(1L, "1"), list(2L, "2")), schema = c("key", "value")) ws <- orderBy(window.partitionBy("key"), "value") @@ -2170,7 +2167,7 @@ test_that("Window functions on a DataFrame", { result <- collect(select(df, over(lead("key", 1), ws), over(lead("value", 1), ws))) names(result) <- c("key", "value") expect_equal(result, expected) - remove(".sparkRHivesc", envir = .sparkREnv) + unsetHiveContext() }) test_that("createDataFrame sqlContext parameter backward compatibility", { From 640ffcaa2836d40337cda6c3ba1e51f30d16a44c Mon Sep 17 00:00:00 2001 From: felixcheung Date: Wed, 25 May 2016 20:24:08 -0700 Subject: [PATCH 13/15] fix bug --- R/pkg/inst/tests/testthat/test_sparkSQL.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 2a7dacba186e..7edab2b5cf7d 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -33,6 +33,7 @@ markUtf8 <- function(s) { } setHiveContext <- function() { + ssc <- callJMethod(sc, "sc") hiveCtx <- tryCatch({ newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) }, @@ -177,7 +178,6 @@ test_that("create DataFrame from RDD", { expect_equal(as.list(collect(where(df, df$name == "John"))), list(name = "John", age = 19L, height = 176.5)) - ssc <- callJMethod(sc, "sc") setHiveContext() sql("CREATE TABLE people (name string, age double, height float)") df <- read.df(jsonPathNa, "json", schema) From 90641a71ff1860ddfe1a8e0bcb64cc0f0d2a56c6 Mon Sep 17 00:00:00 2001 From: felixcheung Date: Wed, 25 May 2016 20:39:21 -0700 Subject: [PATCH 14/15] left one line by accident --- R/pkg/inst/tests/testthat/test_sparkSQL.R | 1 - 1 file changed, 1 deletion(-) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 7edab2b5cf7d..b9252cc48f8d 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -187,7 +187,6 @@ test_that("create DataFrame from RDD", { expect_equal(collect(sql("SELECT height from people WHERE name ='Bob'"))$height, c(176.5)) unsetHiveContext() - remove(".sparkRHivesc", envir = .sparkREnv) }) test_that("convert NAs to null type in DataFrames", { From f67095ef72540140aa2348b5262ffdf91685846a Mon Sep 17 00:00:00 2001 From: felixcheung Date: Wed, 25 May 2016 21:04:00 -0700 Subject: [PATCH 15/15] more fix to test --- R/pkg/inst/tests/testthat/test_sparkSQL.R | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index b9252cc48f8d..5910267aadb0 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -32,7 +32,7 @@ markUtf8 <- function(s) { s } -setHiveContext <- function() { +setHiveContext <- function(sc) { ssc <- callJMethod(sc, "sc") hiveCtx <- tryCatch({ newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) @@ -178,7 +178,7 @@ test_that("create DataFrame from RDD", { expect_equal(as.list(collect(where(df, df$name == "John"))), list(name = "John", age = 19L, height = 176.5)) - setHiveContext() + setHiveContext(sc) sql("CREATE TABLE people (name string, age double, height float)") df <- read.df(jsonPathNa, "json", schema) invisible(insertInto(df, "people")) @@ -963,8 +963,7 @@ test_that("column calculation", { }) test_that("test HiveContext", { - ssc <- callJMethod(sc, "sc") - 
setHiveContext() + setHiveContext(sc) df <- createExternalTable("json", jsonPath, "json") expect_is(df, "SparkDataFrame") expect_equal(count(df), 3) @@ -2140,8 +2139,7 @@ test_that("repartition by columns on DataFrame", { }) test_that("Window functions on a DataFrame", { - ssc <- callJMethod(sc, "sc") - setHiveContext() + setHiveContext(sc) df <- createDataFrame(list(list(1L, "1"), list(2L, "2"), list(1L, "1"), list(2L, "2")), schema = c("key", "value")) ws <- orderBy(window.partitionBy("key"), "value")
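# Editorial sketch, not part of the original patch: the intended usage pattern for the
# Hive-backed tests after this refactor. setHiveContext(sc) registers a TestHiveContext as
# ".sparkRHivesc" in .sparkREnv so the context-free API resolves to Hive, and
# unsetHiveContext() removes it again so later tests fall back to the plain SQLContext.
setHiveContext(sc)
df <- createDataFrame(list(list(1L, "1"), list(2L, "2")), schema = c("key", "value"))
expect_is(df, "SparkDataFrame")
unsetHiveContext()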