diff --git a/.gitignore b/.gitignore
index 9f8cd0b4cb232..0991976abfb8b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -17,6 +17,7 @@
.idea/
.idea_modules/
.project
+.pydevproject
.scala_dependencies
.settings
/lib/
@@ -77,3 +78,8 @@ spark-warehouse/
# For R session data
.RData
.RHistory
+.Rhistory
+*.Rproj
+*.Rproj.*
+
+.Rproj.user
diff --git a/R/check-cran.sh b/R/check-cran.sh
index b3a6860961c1e..5c90fd07f28e4 100755
--- a/R/check-cran.sh
+++ b/R/check-cran.sh
@@ -47,6 +47,6 @@ $FWDIR/create-docs.sh
VERSION=`grep Version $FWDIR/pkg/DESCRIPTION | awk '{print $NF}'`
-"$R_SCRIPT_PATH/"R CMD check --as-cran --no-tests SparkR_"$VERSION".tar.gz
+"$R_SCRIPT_PATH/"R CMD check --as-cran SparkR_"$VERSION".tar.gz
popd > /dev/null
diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION
index ac73d6c79891e..357ab007931f5 100644
--- a/R/pkg/DESCRIPTION
+++ b/R/pkg/DESCRIPTION
@@ -7,7 +7,7 @@ Author: The Apache Software Foundation
Maintainer: Shivaram Venkataraman
Depends:
R (>= 3.0),
- methods,
+ methods
Suggests:
testthat,
e1071,
@@ -31,6 +31,7 @@ Collate:
'context.R'
'deserialize.R'
'functions.R'
+ 'install.R'
'mllib.R'
'serialize.R'
'sparkR.R'
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 1d74c6d95578f..aaab92f5cfc7b 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -352,3 +352,5 @@ S3method(structField, character)
S3method(structField, jobj)
S3method(structType, jobj)
S3method(structType, structField)
+
+export("install.spark")
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index aa211b326a167..daf5860024af1 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -120,8 +120,9 @@ setMethod("schema",
#'
#' Print the logical and physical Catalyst plans to the console for debugging.
#'
-#' @param x A SparkDataFrame
+#' @param x a SparkDataFrame.
#' @param extended Logical. If extended is FALSE, explain() only prints the physical plan.
+#' @param ... further arguments to be passed to or from other methods.
#' @family SparkDataFrame functions
#' @aliases explain,SparkDataFrame-method
#' @rdname explain
@@ -137,7 +138,7 @@ setMethod("schema",
#' @note explain since 1.4.0
setMethod("explain",
signature(x = "SparkDataFrame"),
- function(x, extended = FALSE) {
+ function(x, extended = FALSE, ...) {
queryExec <- callJMethod(x@sdf, "queryExecution")
if (extended) {
cat(callJMethod(queryExec, "toString"))
@@ -177,11 +178,13 @@ setMethod("isLocal",
#'
#' Print the first numRows rows of a SparkDataFrame
#'
-#' @param x A SparkDataFrame
-#' @param numRows The number of rows to print. Defaults to 20.
-#' @param truncate Whether truncate long strings. If true, strings more than 20 characters will be
-#' truncated and all cells will be aligned right
-#'
+#' @param x a SparkDataFrame.
+#' @param numRows the number of rows to print. Defaults to 20.
+#' @param truncate whether to truncate long strings. If \code{TRUE}, strings longer
+#'                 than 20 characters will be truncated. If set to a number greater
+#'                 than zero, strings longer than \code{truncate} characters will be
+#'                 truncated and all cells will be aligned right.
+#' @param ... further arguments to be passed to or from other methods.
#' @family SparkDataFrame functions
#' @aliases showDF,SparkDataFrame-method
#' @rdname showDF
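An illustrative use of the truncate behaviour documented above, assuming an active SparkR session (the sample data is made up):

df <- createDataFrame(data.frame(id = 1:2,
                                 note = c("short", paste(rep("long", 10), collapse = " "))))
showDF(df, numRows = 2, truncate = TRUE)   # strings beyond 20 characters are cut off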
@@ -206,7 +209,7 @@ setMethod("showDF",
#'
#' Print the SparkDataFrame column names and types
#'
-#' @param x A SparkDataFrame
+#' @param object a SparkDataFrame.
#'
#' @family SparkDataFrame functions
#' @rdname show
@@ -257,11 +260,11 @@ setMethod("dtypes",
})
})
-#' Column names
+#' Column Names of SparkDataFrame
#'
-#' Return all column names as a list
+#' Return all column names as a list.
#'
-#' @param x A SparkDataFrame
+#' @param x a SparkDataFrame.
#'
#' @family SparkDataFrame functions
#' @rdname columns
@@ -318,6 +321,8 @@ setMethod("colnames",
columns(x)
})
+#' @param value a character vector. Must have the same length as the number
+#' of columns in the SparkDataFrame.
#' @rdname columns
#' @aliases colnames<-,SparkDataFrame-method
#' @name colnames<-
@@ -406,7 +411,6 @@ setMethod("coltypes",
#'
#' Set the column types of a SparkDataFrame.
#'
-#' @param x A SparkDataFrame
#' @param value A character vector with the target column types for the given
#' SparkDataFrame. Column types can be one of integer, numeric/double, character, logical, or NA
#' to keep that column as-is.
@@ -510,9 +514,10 @@ setMethod("registerTempTable",
#'
#' Insert the contents of a SparkDataFrame into a table registered in the current SparkSession.
#'
-#' @param x A SparkDataFrame
-#' @param tableName A character vector containing the name of the table
-#' @param overwrite A logical argument indicating whether or not to overwrite
+#' @param x a SparkDataFrame.
+#' @param tableName a character vector containing the name of the table.
+#' @param overwrite a logical argument indicating whether or not to overwrite
 #' the existing rows in the table.
+#' @param ... further arguments to be passed to or from other methods.
#'
#' @family SparkDataFrame functions
@@ -531,7 +536,7 @@ setMethod("registerTempTable",
#' @note insertInto since 1.4.0
setMethod("insertInto",
signature(x = "SparkDataFrame", tableName = "character"),
- function(x, tableName, overwrite = FALSE) {
+ function(x, tableName, overwrite = FALSE, ...) {
jmode <- convertToJSaveMode(ifelse(overwrite, "overwrite", "append"))
write <- callJMethod(x@sdf, "write")
write <- callJMethod(write, "mode", jmode)
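The overwrite flag maps onto the save modes chosen in the body above; a hedged sketch with a hypothetical registered table name:

insertInto(df, "people")                    # append mode: add rows to the table "people"
insertInto(df, "people", overwrite = TRUE)  # overwrite mode: replace the existing rows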
@@ -571,7 +576,9 @@ setMethod("cache",
#' supported storage levels, refer to
#' \url{http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence}.
#'
-#' @param x The SparkDataFrame to persist
+#' @param x the SparkDataFrame to persist.
+#' @param newLevel storage level chosen for the persistence. See available options in
+#' the description.
#'
#' @family SparkDataFrame functions
#' @rdname persist
@@ -634,9 +641,10 @@ setMethod("unpersist",
#' \item{3.} {Return a new SparkDataFrame partitioned by the given column(s),
#' using `spark.sql.shuffle.partitions` as number of partitions.}
#'}
-#' @param x A SparkDataFrame
-#' @param numPartitions The number of partitions to use.
-#' @param col The column by which the partitioning will be performed.
+#' @param x a SparkDataFrame.
+#' @param numPartitions the number of partitions to use.
+#' @param col the column by which the partitioning will be performed.
+#' @param ... additional column(s) to be used in the partitioning.
#'
#' @family SparkDataFrame functions
#' @rdname repartition
@@ -915,8 +923,7 @@ setMethod("sample_frac",
#' Returns the number of rows in a SparkDataFrame
#'
-#' @param x A SparkDataFrame
-#'
+#' @param x a SparkDataFrame.
#' @family SparkDataFrame functions
#' @rdname nrow
#' @name count
@@ -1092,8 +1099,10 @@ setMethod("limit",
dataFrame(res)
})
-#' Take the first NUM rows of a SparkDataFrame and return a the results as a R data.frame
+#' Take the first NUM rows of a SparkDataFrame and return the results as an R data.frame
#'
+#' @param x a SparkDataFrame.
+#' @param num number of rows to take.
#' @family SparkDataFrame functions
#' @rdname take
#' @name take
@@ -1120,9 +1129,9 @@ setMethod("take",
#' then head() returns the first 6 rows in keeping with the current data.frame
#' convention in R.
#'
-#' @param x A SparkDataFrame
-#' @param num The number of rows to return. Default is 6.
-#' @return A data.frame
+#' @param x a SparkDataFrame.
+#' @param num the number of rows to return. Default is 6.
+#' @return A data.frame.
#'
#' @family SparkDataFrame functions
#' @aliases head,SparkDataFrame-method
@@ -1146,7 +1155,7 @@ setMethod("head",
#' Return the first row of a SparkDataFrame
#'
-#' @param x A SparkDataFrame
+#' @param x a SparkDataFrame or a column used in aggregation function.
#'
#' @family SparkDataFrame functions
#' @aliases first,SparkDataFrame-method
@@ -1198,6 +1207,7 @@ setMethod("toRDD",
#' Groups the SparkDataFrame using the specified columns, so we can run aggregation on them.
#'
#' @param x a SparkDataFrame
+#' @param ... variable(s) (character name(s) or Column(s)) to group on.
#' @return a GroupedData
#' @family SparkDataFrame functions
#' @aliases groupBy,SparkDataFrame-method
@@ -1240,7 +1250,6 @@ setMethod("group_by",
#'
#' Compute aggregates by specifying a list of columns
#'
-#' @param x a SparkDataFrame
#' @family SparkDataFrame functions
#' @aliases agg,SparkDataFrame-method
#' @rdname summarize
@@ -1387,16 +1396,15 @@ setMethod("dapplyCollect",
#' Groups the SparkDataFrame using the specified columns and applies the R function to each
#' group.
#'
-#' @param x A SparkDataFrame
-#' @param cols Grouping columns
-#' @param func A function to be applied to each group partition specified by grouping
+#' @param cols grouping columns.
+#' @param func a function to be applied to each group partition specified by grouping
#' column of the SparkDataFrame. The function `func` takes as argument
#' a key - grouping columns and a data frame - a local R data.frame.
#' The output of `func` is a local R data.frame.
-#' @param schema The schema of the resulting SparkDataFrame after the function is applied.
+#' @param schema the schema of the resulting SparkDataFrame after the function is applied.
#' The schema must match to output of `func`. It has to be defined for each
#' output column with preferred output column name and corresponding data type.
-#' @return a SparkDataFrame
+#' @return A SparkDataFrame.
#' @family SparkDataFrame functions
#' @aliases gapply,SparkDataFrame-method
#' @rdname gapply
@@ -1479,13 +1487,12 @@ setMethod("gapply",
#' Groups the SparkDataFrame using the specified columns, applies the R function to each
#' group and collects the result back to R as data.frame.
#'
-#' @param x A SparkDataFrame
-#' @param cols Grouping columns
-#' @param func A function to be applied to each group partition specified by grouping
+#' @param cols grouping columns.
+#' @param func a function to be applied to each group partition specified by grouping
#' column of the SparkDataFrame. The function `func` takes as argument
#' a key - grouping columns and a data frame - a local R data.frame.
#' The output of `func` is a local R data.frame.
-#' @return a data.frame
+#' @return A data.frame.
#' @family SparkDataFrame functions
#' @aliases gapplyCollect,SparkDataFrame-method
#' @rdname gapplyCollect
@@ -1632,6 +1639,7 @@ getColumn <- function(x, c) {
column(callJMethod(x@sdf, "col", c))
}
+#' @param name name of a Column (without being wrapped by \code{""}).
#' @rdname select
#' @name $
#' @aliases $,SparkDataFrame-method
@@ -1641,6 +1649,7 @@ setMethod("$", signature(x = "SparkDataFrame"),
getColumn(x, name)
})
+#' @param value a Column or NULL. If NULL, the specified Column is dropped.
#' @rdname select
#' @name $<-
#' @aliases $<-,SparkDataFrame-method
@@ -1715,12 +1724,13 @@ setMethod("[", signature(x = "SparkDataFrame"),
#' Subset
#'
#' Return subsets of SparkDataFrame according to given conditions
-#' @param x A SparkDataFrame
-#' @param subset (Optional) A logical expression to filter on rows
-#' @param select expression for the single Column or a list of columns to select from the SparkDataFrame
+#' @param x a SparkDataFrame.
+#' @param i,subset (Optional) a logical expression to filter on rows.
+#' @param j,select expression for the single Column or a list of columns to select from the SparkDataFrame.
#' @param drop if TRUE, a Column will be returned if the resulting dataset has only one column.
-#' Otherwise, a SparkDataFrame will always be returned.
-#' @return A new SparkDataFrame containing only the rows that meet the condition with selected columns
+#' Otherwise, a SparkDataFrame will always be returned.
+#' @param ... currently not used.
+#' @return A new SparkDataFrame containing only the rows that meet the condition with selected columns.
#' @export
#' @family SparkDataFrame functions
#' @aliases subset,SparkDataFrame-method
@@ -1755,9 +1765,12 @@ setMethod("subset", signature(x = "SparkDataFrame"),
#' Select
#'
#' Selects a set of columns with names or Column expressions.
-#' @param x A SparkDataFrame
-#' @param col A list of columns or single Column or name
-#' @return A new SparkDataFrame with selected columns
+#' @param x a SparkDataFrame.
+#' @param col a list of columns or single Column or name.
+#' @param ... additional column(s) if only one column is specified in \code{col}.
+#' If more than one column is assigned in \code{col}, \code{...}
+#' should be left empty.
+#' @return A new SparkDataFrame with selected columns.
#' @export
#' @family SparkDataFrame functions
#' @rdname select
@@ -1854,9 +1867,9 @@ setMethod("selectExpr",
#' Return a new SparkDataFrame by adding a column or replacing the existing column
#' that has the same name.
#'
-#' @param x A SparkDataFrame
-#' @param colName A column name.
-#' @param col A Column expression.
+#' @param x a SparkDataFrame.
+#' @param colName a column name.
+#' @param col a Column expression.
#' @return A SparkDataFrame with the new column added or the existing column replaced.
#' @family SparkDataFrame functions
#' @aliases withColumn,SparkDataFrame,character,Column-method
@@ -1885,8 +1898,8 @@ setMethod("withColumn",
#'
#' Return a new SparkDataFrame with the specified columns added or replaced.
#'
-#' @param .data A SparkDataFrame
-#' @param col a named argument of the form name = col
+#' @param .data a SparkDataFrame.
+#' @param ... additional column argument(s) each in the form name = col.
#' @return A new SparkDataFrame with the new columns added or replaced.
#' @family SparkDataFrame functions
#' @aliases mutate,SparkDataFrame-method
@@ -1963,6 +1976,7 @@ setMethod("mutate",
do.call(select, c(x, colList, deDupCols))
})
+#' @param _data a SparkDataFrame.
#' @export
#' @rdname mutate
#' @aliases transform,SparkDataFrame-method
@@ -2044,14 +2058,14 @@ setMethod("rename",
setClassUnion("characterOrColumn", c("character", "Column"))
-#' Arrange
+#' Arrange Rows by Variables
#'
#' Sort a SparkDataFrame by the specified column(s).
#'
-#' @param x A SparkDataFrame to be sorted.
-#' @param col A character or Column object vector indicating the fields to sort on
-#' @param ... Additional sorting fields
-#' @param decreasing A logical argument indicating sorting order for columns when
+#' @param x a SparkDataFrame to be sorted.
+#' @param col a character or Column object indicating the fields to sort on
+#' @param ... additional sorting fields
+#' @param decreasing a logical argument indicating sorting order for columns when
#' a character vector is specified for col
#' @return A SparkDataFrame where all elements are sorted.
#' @family SparkDataFrame functions
@@ -2116,7 +2130,6 @@ setMethod("arrange",
})
#' @rdname arrange
-#' @name orderBy
#' @aliases orderBy,SparkDataFrame,characterOrColumn-method
#' @export
#' @note orderBy(SparkDataFrame, characterOrColumn) since 1.4.0
@@ -2275,11 +2288,18 @@ setMethod("join",
#' specified, the common column names in \code{x} and \code{y} will be used.
#' @param by.x a character vector specifying the joining columns for x.
#' @param by.y a character vector specifying the joining columns for y.
+#' @param all a boolean value setting \code{all.x} and \code{all.y}
+#' if any of them are unset.
#' @param all.x a boolean value indicating whether all the rows in x should
#' be including in the join
#' @param all.y a boolean value indicating whether all the rows in y should
#' be including in the join
#' @param sort a logical argument indicating whether the resulting columns should be sorted
+#' @param suffixes a string vector of length 2 used to make colnames of
+#' \code{x} and \code{y} unique.
+#' The first element is appended to each colname of \code{x}.
+#' The second element is appended to each colname of \code{y}.
+#' @param ... additional argument(s) passed to the method.
#' @details If all.x and all.y are set to FALSE, a natural join will be returned. If
#' all.x is set to TRUE and all.y is set to FALSE, a left outer join will
#' be returned. If all.x is set to FALSE and all.y is set to TRUE, a right
@@ -2308,7 +2328,7 @@ setMethod("merge",
signature(x = "SparkDataFrame", y = "SparkDataFrame"),
function(x, y, by = intersect(names(x), names(y)), by.x = by, by.y = by,
all = FALSE, all.x = all, all.y = all,
- sort = TRUE, suffixes = c("_x", "_y"), ... ) {
+ sort = TRUE, suffixes = c("_x", "_y"), ...) {
if (length(suffixes) != 2) {
stop("suffixes must have length 2")
@@ -2461,8 +2481,9 @@ setMethod("unionAll",
#' Union two or more SparkDataFrames. This is equivalent to `UNION ALL` in SQL.
#' Note that this does not remove duplicate rows across the two SparkDataFrames.
#'
-#' @param x A SparkDataFrame
-#' @param ... Additional SparkDataFrame
+#' @param x a SparkDataFrame.
+#' @param ... additional SparkDataFrame(s).
+#' @param deparse.level dummy variable, currently not used.
#' @return A SparkDataFrame containing the result of the union.
#' @family SparkDataFrame functions
#' @aliases rbind,SparkDataFrame-method
@@ -2519,8 +2540,8 @@ setMethod("intersect",
#' Return a new SparkDataFrame containing rows in this SparkDataFrame
#' but not in another SparkDataFrame. This is equivalent to `EXCEPT` in SQL.
#'
-#' @param x A SparkDataFrame
-#' @param y A SparkDataFrame
+#' @param x a SparkDataFrame.
+#' @param y a SparkDataFrame.
#' @return A SparkDataFrame containing the result of the except operation.
#' @family SparkDataFrame functions
#' @aliases except,SparkDataFrame,SparkDataFrame-method
@@ -2561,10 +2582,11 @@ setMethod("except",
#' and to not change the existing data.
#' }
#'
-#' @param df A SparkDataFrame
-#' @param path A name for the table
-#' @param source A name for external data source
-#' @param mode One of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default)
+#' @param df a SparkDataFrame.
+#' @param path a name for the table.
+#' @param source a name for external data source.
+#' @param mode one of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default).
+#' @param ... additional argument(s) passed to the method.
#'
#' @family SparkDataFrame functions
#' @aliases write.df,SparkDataFrame,character-method
@@ -2623,10 +2645,11 @@ setMethod("saveDF",
#' ignore: The save operation is expected to not save the contents of the SparkDataFrame
#' and to not change the existing data. \cr
#'
-#' @param df A SparkDataFrame
-#' @param tableName A name for the table
-#' @param source A name for external data source
-#' @param mode One of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default)
+#' @param df a SparkDataFrame.
+#' @param tableName a name for the table.
+#' @param source a name for external data source.
+#' @param mode one of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default).
+#' @param ... additional option(s) passed to the method.
#'
#' @family SparkDataFrame functions
#' @aliases saveAsTable,SparkDataFrame,character-method
@@ -2662,10 +2685,10 @@ setMethod("saveAsTable",
#' Computes statistics for numeric columns.
#' If no columns are given, this function computes statistics for all numerical columns.
#'
-#' @param x A SparkDataFrame to be computed.
-#' @param col A string of name
-#' @param ... Additional expressions
-#' @return A SparkDataFrame
+#' @param x a SparkDataFrame to be computed.
+#' @param col a string of a column name.
+#' @param ... additional expressions.
+#' @return A SparkDataFrame.
#' @family SparkDataFrame functions
#' @aliases describe,SparkDataFrame,character-method describe,SparkDataFrame,ANY-method
#' @rdname summary
@@ -2700,6 +2723,7 @@ setMethod("describe",
dataFrame(sdf)
})
+#' @param object a SparkDataFrame to be summarized.
#' @rdname summary
#' @name summary
#' @aliases summary,SparkDataFrame-method
@@ -2715,16 +2739,20 @@ setMethod("summary",
#'
#' dropna, na.omit - Returns a new SparkDataFrame omitting rows with null values.
#'
-#' @param x A SparkDataFrame.
+#' @param x a SparkDataFrame.
#' @param how "any" or "all".
#' if "any", drop a row if it contains any nulls.
#' if "all", drop a row only if all its values are null.
#' if minNonNulls is specified, how is ignored.
-#' @param minNonNulls If specified, drop rows that have less than
+#' @param minNonNulls if specified, drop rows that have less than
#' minNonNulls non-null values.
#' This overwrites the how parameter.
-#' @param cols Optional list of column names to consider.
-#' @return A SparkDataFrame
+#' @param cols optional list of column names to consider. In `fillna`,
+#' columns specified in cols that do not have matching data
+#' type are ignored. For example, if value is a character, and
+#' cols contains a non-character column, then the non-character
+#' column is simply ignored.
+#' @return A SparkDataFrame.
#'
#' @family SparkDataFrame functions
#' @rdname nafunctions
@@ -2756,6 +2784,8 @@ setMethod("dropna",
dataFrame(sdf)
})
+#' @param object a SparkDataFrame.
+#' @param ... further arguments to be passed to or from other methods.
#' @rdname nafunctions
#' @name na.omit
#' @aliases na.omit,SparkDataFrame-method
@@ -2763,24 +2793,18 @@ setMethod("dropna",
#' @note na.omit since 1.5.0
setMethod("na.omit",
signature(object = "SparkDataFrame"),
- function(object, how = c("any", "all"), minNonNulls = NULL, cols = NULL) {
+ function(object, how = c("any", "all"), minNonNulls = NULL, cols = NULL, ...) {
dropna(object, how, minNonNulls, cols)
})
#' fillna - Replace null values.
#'
-#' @param x A SparkDataFrame.
-#' @param value Value to replace null values with.
+#' @param value value to replace null values with.
#' Should be an integer, numeric, character or named list.
#' If the value is a named list, then cols is ignored and
#' value must be a mapping from column name (character) to
#' replacement value. The replacement value must be an
#' integer, numeric or character.
-#' @param cols optional list of column names to consider.
-#' Columns specified in cols that do not have matching data
-#' type are ignored. For example, if value is a character, and
-#' subset contains a non-character column, then the non-character
-#' column is simply ignored.
#'
#' @rdname nafunctions
#' @name fillna
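A hedged sketch of both families documented in this block (df is a hypothetical SparkDataFrame with age and name columns):

dropna(df, how = "any")                               # drop rows containing any NULL
fillna(df, value = 0)                                 # replace NULLs in numeric columns with 0
fillna(df, value = list(age = -1, name = "unknown"))  # per-column replacement values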
@@ -2845,8 +2869,11 @@ setMethod("fillna",
#' Since data.frames are held in memory, ensure that you have enough memory
#' in your system to accommodate the contents.
#'
-#' @param x a SparkDataFrame
-#' @return a data.frame
+#' @param x a SparkDataFrame.
+#' @param row.names NULL or a character vector giving the row names for the data frame.
+#' @param optional If `TRUE`, converting column names is optional.
+#' @param ... additional arguments to pass to base::as.data.frame.
+#' @return A data.frame.
#' @family SparkDataFrame functions
#' @aliases as.data.frame,SparkDataFrame-method
#' @rdname as.data.frame
@@ -3000,9 +3027,10 @@ setMethod("str",
#' Returns a new SparkDataFrame with columns dropped.
#' This is a no-op if schema doesn't contain column name(s).
#'
-#' @param x A SparkDataFrame.
-#' @param cols A character vector of column names or a Column.
-#' @return A SparkDataFrame
+#' @param x a SparkDataFrame.
+#' @param ... further arguments to be passed to or from other methods.
+#' @param col a character vector of column names or a Column.
+#' @return A SparkDataFrame.
#'
#' @family SparkDataFrame functions
#' @rdname drop
@@ -3049,8 +3077,8 @@ setMethod("drop",
#'
#' @name histogram
#' @param nbins the number of bins (optional). Default value is 10.
+#' @param col the column, as a character string or a Column object, to build the histogram from.
#' @param df the SparkDataFrame containing the Column to build the histogram from.
-#' @param colname the name of the column to build the histogram from.
#' @return a data.frame with the histogram statistics, i.e., counts and centroids.
#' @rdname histogram
#' @aliases histogram,SparkDataFrame,characterOrColumn-method
@@ -3184,6 +3212,7 @@ setMethod("histogram",
#' @param x A SparkDataFrame
#' @param url JDBC database url of the form `jdbc:subprotocol:subname`
#' @param tableName The name of the table in the external database
+#' @param ... additional JDBC database connection properties.
#' @param mode One of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default)
#' @family SparkDataFrame functions
#' @rdname write.jdbc
diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R
index 72a805256523e..6b254bb0d302c 100644
--- a/R/pkg/R/RDD.R
+++ b/R/pkg/R/RDD.R
@@ -67,7 +67,7 @@ setMethod("initialize", "RDD", function(.Object, jrdd, serializedMode,
.Object
})
-setMethod("show", "RDD",
+setMethod("showRDD", "RDD",
function(object) {
cat(paste(callJMethod(getJRDD(object), "toString"), "\n", sep = ""))
})
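The RDD methods renamed in this file keep working under their *RDD names but stay internal (@noRd); a hedged sketch of exercising them from outside the package namespace, in the style of the examples below:

sc <- sparkR.init()
rdd <- SparkR:::parallelize(sc, 1:10, 2L)
SparkR:::countRDD(rdd)      # 10
SparkR:::collectRDD(rdd)    # list(1, 2, ..., 10)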
@@ -215,7 +215,7 @@ setValidity("RDD",
#' @rdname cache-methods
#' @aliases cache,RDD-method
#' @noRd
-setMethod("cache",
+setMethod("cacheRDD",
signature(x = "RDD"),
function(x) {
callJMethod(getJRDD(x), "cache")
@@ -235,12 +235,12 @@ setMethod("cache",
#'\dontrun{
#' sc <- sparkR.init()
#' rdd <- parallelize(sc, 1:10, 2L)
-#' persist(rdd, "MEMORY_AND_DISK")
+#' persistRDD(rdd, "MEMORY_AND_DISK")
#'}
#' @rdname persist
#' @aliases persist,RDD-method
#' @noRd
-setMethod("persist",
+setMethod("persistRDD",
signature(x = "RDD", newLevel = "character"),
function(x, newLevel = "MEMORY_ONLY") {
callJMethod(getJRDD(x), "persist", getStorageLevel(newLevel))
@@ -259,12 +259,12 @@ setMethod("persist",
#' sc <- sparkR.init()
#' rdd <- parallelize(sc, 1:10, 2L)
#' cache(rdd) # rdd@@env$isCached == TRUE
-#' unpersist(rdd) # rdd@@env$isCached == FALSE
+#' unpersistRDD(rdd) # rdd@@env$isCached == FALSE
#'}
#' @rdname unpersist-methods
#' @aliases unpersist,RDD-method
#' @noRd
-setMethod("unpersist",
+setMethod("unpersistRDD",
signature(x = "RDD"),
function(x) {
callJMethod(getJRDD(x), "unpersist")
@@ -345,13 +345,13 @@ setMethod("numPartitions",
#'\dontrun{
#' sc <- sparkR.init()
#' rdd <- parallelize(sc, 1:10, 2L)
-#' collect(rdd) # list from 1 to 10
+#' collectRDD(rdd) # list from 1 to 10
#' collectPartition(rdd, 0L) # list from 1 to 5
#'}
#' @rdname collect-methods
#' @aliases collect,RDD-method
#' @noRd
-setMethod("collect",
+setMethod("collectRDD",
signature(x = "RDD"),
function(x, flatten = TRUE) {
# Assumes a pairwise RDD is backed by a JavaPairRDD.
@@ -397,7 +397,7 @@ setMethod("collectPartition",
setMethod("collectAsMap",
signature(x = "RDD"),
function(x) {
- pairList <- collect(x)
+ pairList <- collectRDD(x)
map <- new.env()
lapply(pairList, function(i) { assign(as.character(i[[1]]), i[[2]], envir = map) })
as.list(map)
@@ -411,30 +411,30 @@ setMethod("collectAsMap",
#'\dontrun{
#' sc <- sparkR.init()
#' rdd <- parallelize(sc, 1:10)
-#' count(rdd) # 10
+#' countRDD(rdd) # 10
#' length(rdd) # Same as count
#'}
#' @rdname count
#' @aliases count,RDD-method
#' @noRd
-setMethod("count",
+setMethod("countRDD",
signature(x = "RDD"),
function(x) {
countPartition <- function(part) {
as.integer(length(part))
}
valsRDD <- lapplyPartition(x, countPartition)
- vals <- collect(valsRDD)
+ vals <- collectRDD(valsRDD)
sum(as.integer(vals))
})
#' Return the number of elements in the RDD
#' @rdname count
#' @noRd
-setMethod("length",
+setMethod("lengthRDD",
signature(x = "RDD"),
function(x) {
- count(x)
+ countRDD(x)
})
#' Return the count of each unique value in this RDD as a list of
@@ -460,7 +460,7 @@ setMethod("countByValue",
signature(x = "RDD"),
function(x) {
ones <- lapply(x, function(item) { list(item, 1L) })
- collect(reduceByKey(ones, `+`, getNumPartitions(x)))
+ collectRDD(reduceByKey(ones, `+`, getNumPartitions(x)))
})
#' Apply a function to all elements
@@ -479,7 +479,7 @@ setMethod("countByValue",
#' sc <- sparkR.init()
#' rdd <- parallelize(sc, 1:10)
#' multiplyByTwo <- lapply(rdd, function(x) { x * 2 })
-#' collect(multiplyByTwo) # 2,4,6...
+#' collectRDD(multiplyByTwo) # 2,4,6...
#'}
setMethod("lapply",
signature(X = "RDD", FUN = "function"),
@@ -512,7 +512,7 @@ setMethod("map",
#' sc <- sparkR.init()
#' rdd <- parallelize(sc, 1:10)
#' multiplyByTwo <- flatMap(rdd, function(x) { list(x*2, x*10) })
-#' collect(multiplyByTwo) # 2,20,4,40,6,60...
+#' collectRDD(multiplyByTwo) # 2,20,4,40,6,60...
#'}
#' @rdname flatMap
#' @aliases flatMap,RDD,function-method
@@ -541,7 +541,7 @@ setMethod("flatMap",
#' sc <- sparkR.init()
#' rdd <- parallelize(sc, 1:10)
#' partitionSum <- lapplyPartition(rdd, function(part) { Reduce("+", part) })
-#' collect(partitionSum) # 15, 40
+#' collectRDD(partitionSum) # 15, 40
#'}
#' @rdname lapplyPartition
#' @aliases lapplyPartition,RDD,function-method
@@ -576,7 +576,7 @@ setMethod("mapPartitions",
#' rdd <- parallelize(sc, 1:10, 5L)
#' prod <- lapplyPartitionsWithIndex(rdd, function(partIndex, part) {
#' partIndex * Reduce("+", part) })
-#' collect(prod, flatten = FALSE) # 0, 7, 22, 45, 76
+#' collectRDD(prod, flatten = FALSE) # 0, 7, 22, 45, 76
#'}
#' @rdname lapplyPartitionsWithIndex
#' @aliases lapplyPartitionsWithIndex,RDD,function-method
@@ -607,7 +607,7 @@ setMethod("mapPartitionsWithIndex",
#'\dontrun{
#' sc <- sparkR.init()
#' rdd <- parallelize(sc, 1:10)
-#' unlist(collect(filterRDD(rdd, function (x) { x < 3 }))) # c(1, 2)
+#' unlist(collectRDD(filterRDD(rdd, function (x) { x < 3 }))) # c(1, 2)
#'}
# nolint end
#' @rdname filterRDD
@@ -656,7 +656,7 @@ setMethod("reduce",
Reduce(func, part)
}
- partitionList <- collect(lapplyPartition(x, reducePartition),
+ partitionList <- collectRDD(lapplyPartition(x, reducePartition),
flatten = FALSE)
Reduce(func, partitionList)
})
@@ -736,7 +736,7 @@ setMethod("foreach",
lapply(x, func)
NULL
}
- invisible(collect(mapPartitions(x, partition.func)))
+ invisible(collectRDD(mapPartitions(x, partition.func)))
})
#' Applies a function to each partition in an RDD, and forces evaluation.
@@ -753,7 +753,7 @@ setMethod("foreach",
setMethod("foreachPartition",
signature(x = "RDD", func = "function"),
function(x, func) {
- invisible(collect(mapPartitions(x, func)))
+ invisible(collectRDD(mapPartitions(x, func)))
})
#' Take elements from an RDD.
@@ -768,13 +768,13 @@ setMethod("foreachPartition",
#'\dontrun{
#' sc <- sparkR.init()
#' rdd <- parallelize(sc, 1:10)
-#' take(rdd, 2L) # list(1, 2)
+#' takeRDD(rdd, 2L) # list(1, 2)
#'}
# nolint end
#' @rdname take
#' @aliases take,RDD,numeric-method
#' @noRd
-setMethod("take",
+setMethod("takeRDD",
signature(x = "RDD", num = "numeric"),
function(x, num) {
resList <- list()
@@ -817,13 +817,13 @@ setMethod("take",
#'\dontrun{
#' sc <- sparkR.init()
#' rdd <- parallelize(sc, 1:10)
-#' first(rdd)
+#' firstRDD(rdd)
#' }
#' @noRd
-setMethod("first",
+setMethod("firstRDD",
signature(x = "RDD"),
function(x) {
- take(x, 1)[[1]]
+ takeRDD(x, 1)[[1]]
})
#' Removes the duplicates from RDD.
@@ -838,13 +838,13 @@ setMethod("first",
#'\dontrun{
#' sc <- sparkR.init()
#' rdd <- parallelize(sc, c(1,2,2,3,3,3))
-#' sort(unlist(collect(distinct(rdd)))) # c(1, 2, 3)
+#' sort(unlist(collectRDD(distinctRDD(rdd)))) # c(1, 2, 3)
#'}
# nolint end
#' @rdname distinct
#' @aliases distinct,RDD-method
#' @noRd
-setMethod("distinct",
+setMethod("distinctRDD",
signature(x = "RDD"),
function(x, numPartitions = SparkR:::getNumPartitions(x)) {
identical.mapped <- lapply(x, function(x) { list(x, NULL) })
@@ -868,8 +868,8 @@ setMethod("distinct",
#'\dontrun{
#' sc <- sparkR.init()
#' rdd <- parallelize(sc, 1:10)
-#' collect(sampleRDD(rdd, FALSE, 0.5, 1618L)) # ~5 distinct elements
-#' collect(sampleRDD(rdd, TRUE, 0.5, 9L)) # ~5 elements possibly with duplicates
+#' collectRDD(sampleRDD(rdd, FALSE, 0.5, 1618L)) # ~5 distinct elements
+#' collectRDD(sampleRDD(rdd, TRUE, 0.5, 9L)) # ~5 elements possibly with duplicates
#'}
#' @rdname sampleRDD
#' @aliases sampleRDD,RDD
@@ -942,7 +942,7 @@ setMethod("takeSample", signature(x = "RDD", withReplacement = "logical",
fraction <- 0.0
total <- 0
multiplier <- 3.0
- initialCount <- count(x)
+ initialCount <- countRDD(x)
maxSelected <- 0
MAXINT <- .Machine$integer.max
@@ -964,7 +964,7 @@ setMethod("takeSample", signature(x = "RDD", withReplacement = "logical",
}
set.seed(seed)
- samples <- collect(sampleRDD(x, withReplacement, fraction,
+ samples <- collectRDD(sampleRDD(x, withReplacement, fraction,
as.integer(ceiling(runif(1,
-MAXINT,
MAXINT)))))
@@ -972,7 +972,7 @@ setMethod("takeSample", signature(x = "RDD", withReplacement = "logical",
# take samples; this shouldn't happen often because we use a big
# multiplier for thei initial size
while (length(samples) < total)
- samples <- collect(sampleRDD(x, withReplacement, fraction,
+ samples <- collectRDD(sampleRDD(x, withReplacement, fraction,
as.integer(ceiling(runif(1,
-MAXINT,
MAXINT)))))
@@ -990,7 +990,7 @@ setMethod("takeSample", signature(x = "RDD", withReplacement = "logical",
#'\dontrun{
#' sc <- sparkR.init()
#' rdd <- parallelize(sc, list(1, 2, 3))
-#' collect(keyBy(rdd, function(x) { x*x })) # list(list(1, 1), list(4, 2), list(9, 3))
+#' collectRDD(keyBy(rdd, function(x) { x*x })) # list(list(1, 1), list(4, 2), list(9, 3))
#'}
# nolint end
#' @rdname keyBy
@@ -1019,12 +1019,12 @@ setMethod("keyBy",
#' sc <- sparkR.init()
#' rdd <- parallelize(sc, list(1, 2, 3, 4, 5, 6, 7), 4L)
#' getNumPartitions(rdd) # 4
-#' getNumPartitions(repartition(rdd, 2L)) # 2
+#' getNumPartitions(repartitionRDD(rdd, 2L)) # 2
#'}
#' @rdname repartition
#' @aliases repartition,RDD
#' @noRd
-setMethod("repartition",
+setMethod("repartitionRDD",
signature(x = "RDD"),
function(x, numPartitions) {
if (!is.null(numPartitions) && is.numeric(numPartitions)) {
@@ -1064,7 +1064,7 @@ setMethod("coalesce",
})
}
shuffled <- lapplyPartitionsWithIndex(x, func)
- repartitioned <- partitionBy(shuffled, numPartitions)
+ repartitioned <- partitionByRDD(shuffled, numPartitions)
values(repartitioned)
} else {
jrdd <- callJMethod(getJRDD(x), "coalesce", numPartitions, shuffle)
@@ -1135,7 +1135,7 @@ setMethod("saveAsTextFile",
#'\dontrun{
#' sc <- sparkR.init()
#' rdd <- parallelize(sc, list(3, 2, 1))
-#' collect(sortBy(rdd, function(x) { x })) # list (1, 2, 3)
+#' collectRDD(sortBy(rdd, function(x) { x })) # list (1, 2, 3)
#'}
# nolint end
#' @rdname sortBy
@@ -1304,7 +1304,7 @@ setMethod("aggregateRDD",
Reduce(seqOp, part, zeroValue)
}
- partitionList <- collect(lapplyPartition(x, partitionFunc),
+ partitionList <- collectRDD(lapplyPartition(x, partitionFunc),
flatten = FALSE)
Reduce(combOp, partitionList, zeroValue)
})
@@ -1322,7 +1322,7 @@ setMethod("aggregateRDD",
#'\dontrun{
#' sc <- sparkR.init()
#' rdd <- parallelize(sc, 1:10)
-#' collect(pipeRDD(rdd, "more")
+#' pipeRDD(rdd, "more")
#' Output: c("1", "2", ..., "10")
#'}
#' @aliases pipeRDD,RDD,character-method
@@ -1397,7 +1397,7 @@ setMethod("setName",
#'\dontrun{
#' sc <- sparkR.init()
#' rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L)
-#' collect(zipWithUniqueId(rdd))
+#' collectRDD(zipWithUniqueId(rdd))
#' # list(list("a", 0), list("b", 3), list("c", 1), list("d", 4), list("e", 2))
#'}
# nolint end
@@ -1440,7 +1440,7 @@ setMethod("zipWithUniqueId",
#'\dontrun{
#' sc <- sparkR.init()
#' rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L)
-#' collect(zipWithIndex(rdd))
+#' collectRDD(zipWithIndex(rdd))
#' # list(list("a", 0), list("b", 1), list("c", 2), list("d", 3), list("e", 4))
#'}
# nolint end
@@ -1452,7 +1452,7 @@ setMethod("zipWithIndex",
function(x) {
n <- getNumPartitions(x)
if (n > 1) {
- nums <- collect(lapplyPartition(x,
+ nums <- collectRDD(lapplyPartition(x,
function(part) {
list(length(part))
}))
@@ -1488,7 +1488,7 @@ setMethod("zipWithIndex",
#'\dontrun{
#' sc <- sparkR.init()
#' rdd <- parallelize(sc, as.list(1:4), 2L)
-#' collect(glom(rdd))
+#' collectRDD(glom(rdd))
#' # list(list(1, 2), list(3, 4))
#'}
# nolint end
@@ -1556,7 +1556,7 @@ setMethod("unionRDD",
#' sc <- sparkR.init()
#' rdd1 <- parallelize(sc, 0:4)
#' rdd2 <- parallelize(sc, 1000:1004)
-#' collect(zipRDD(rdd1, rdd2))
+#' collectRDD(zipRDD(rdd1, rdd2))
#' # list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004))
#'}
# nolint end
@@ -1628,7 +1628,7 @@ setMethod("cartesian",
#' sc <- sparkR.init()
#' rdd1 <- parallelize(sc, list(1, 1, 2, 2, 3, 4))
#' rdd2 <- parallelize(sc, list(2, 4))
-#' collect(subtract(rdd1, rdd2))
+#' collectRDD(subtract(rdd1, rdd2))
#' # list(1, 1, 3)
#'}
# nolint end
@@ -1662,7 +1662,7 @@ setMethod("subtract",
#' sc <- sparkR.init()
#' rdd1 <- parallelize(sc, list(1, 10, 2, 3, 4, 5))
#' rdd2 <- parallelize(sc, list(1, 6, 2, 3, 7, 8))
-#' collect(sortBy(intersection(rdd1, rdd2), function(x) { x }))
+#' collectRDD(sortBy(intersection(rdd1, rdd2), function(x) { x }))
#' # list(1, 2, 3)
#'}
# nolint end
@@ -1699,7 +1699,7 @@ setMethod("intersection",
#' rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2
#' rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4
#' rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6
-#' collect(zipPartitions(rdd1, rdd2, rdd3,
+#' collectRDD(zipPartitions(rdd1, rdd2, rdd3,
#' func = function(x, y, z) { list(list(x, y, z))} ))
#' # list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6)))
#'}
diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R
index a14bcd91b3eac..b9a7f2b551f7a 100644
--- a/R/pkg/R/SQLContext.R
+++ b/R/pkg/R/SQLContext.R
@@ -165,9 +165,9 @@ getDefaultSqlSource <- function() {
#'
#' Converts R data.frame or list into SparkDataFrame.
#'
-#' @param data An RDD or list or data.frame
-#' @param schema a list of column names or named list (StructType), optional
-#' @return a SparkDataFrame
+#' @param data an RDD or list or data.frame.
+#' @param schema a list of column names or named list (StructType), optional.
+#' @return A SparkDataFrame.
#' @rdname createDataFrame
#' @export
#' @examples
@@ -181,7 +181,7 @@ getDefaultSqlSource <- function() {
#' @method createDataFrame default
#' @note createDataFrame since 1.4.0
# TODO(davies): support sampling and infer type from NA
-createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) {
+createDataFrame.default <- function(data, schema = NULL) {
sparkSession <- getSparkSession()
if (is.data.frame(data)) {
# get the names of columns, they will be put into RDD
@@ -218,7 +218,7 @@ createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) {
}
if (is.null(schema) || (!inherits(schema, "structType") && is.null(names(schema)))) {
- row <- first(rdd)
+ row <- firstRDD(rdd)
names <- if (is.null(schema)) {
names(row)
} else {
@@ -257,23 +257,25 @@ createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) {
}
createDataFrame <- function(x, ...) {
- dispatchFunc("createDataFrame(data, schema = NULL, samplingRatio = 1.0)", x, ...)
+ dispatchFunc("createDataFrame(data, schema = NULL)", x, ...)
}
+#' @param samplingRatio currently not used.
#' @rdname createDataFrame
#' @aliases createDataFrame
#' @export
#' @method as.DataFrame default
#' @note as.DataFrame since 1.6.0
as.DataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) {
- createDataFrame(data, schema, samplingRatio)
+ createDataFrame(data, schema)
}
+#' @param ... additional argument(s).
#' @rdname createDataFrame
#' @aliases as.DataFrame
#' @export
-as.DataFrame <- function(x, ...) {
- dispatchFunc("as.DataFrame(data, schema = NULL, samplingRatio = 1.0)", x, ...)
+as.DataFrame <- function(data, ...) {
+ dispatchFunc("as.DataFrame(data, schema = NULL)", data, ...)
}
#' toDF
@@ -398,7 +400,7 @@ read.orc <- function(path) {
#'
#' Loads a Parquet file, returning the result as a SparkDataFrame.
#'
-#' @param path Path of file to read. A vector of multiple paths is allowed.
+#' @param path path of file to read. A vector of multiple paths is allowed.
#' @return SparkDataFrame
#' @rdname read.parquet
#' @export
@@ -418,6 +420,7 @@ read.parquet <- function(x, ...) {
dispatchFunc("read.parquet(...)", x, ...)
}
+#' @param ... argument(s) passed to the method.
#' @rdname read.parquet
#' @name parquetFile
#' @export
@@ -727,6 +730,7 @@ dropTempView <- function(viewName) {
#' @param source The name of external data source
#' @param schema The data schema defined in structType
#' @param na.strings Default string value for NA when source is "csv"
+#' @param ... additional external data source specific named properties.
#' @return SparkDataFrame
#' @rdname read.df
#' @name read.df
@@ -791,10 +795,11 @@ loadDF <- function(x, ...) {
#' If `source` is not specified, the default data source configured by
#' "spark.sql.sources.default" will be used.
#'
-#' @param tableName A name of the table
-#' @param path The path of files to load
-#' @param source the name of external data source
-#' @return SparkDataFrame
+#' @param tableName a name of the table.
+#' @param path the path of files to load.
+#' @param source the name of external data source.
+#' @param ... additional argument(s) passed to the method.
+#' @return A SparkDataFrame.
#' @rdname createExternalTable
#' @export
#' @examples
@@ -840,6 +845,7 @@ createExternalTable <- function(x, ...) {
#' clause expressions used to split the column `partitionColumn` evenly.
#' This defaults to SparkContext.defaultParallelism when unset.
#' @param predicates a list of conditions in the where clause; each one defines one partition
+#' @param ... additional JDBC database connection named properties.
#' @return SparkDataFrame
#' @rdname read.jdbc
#' @name read.jdbc
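A sketch of how the documented "..." pass-through options might look in practice (the path, JDBC URL and credentials are placeholders):

people <- read.df("data/people.csv", source = "csv", header = "true", inferSchema = "true")
orders <- read.jdbc("jdbc:postgresql:testdb", "orders",
                    partitionColumn = "id", lowerBound = 1, upperBound = 10000,
                    numPartitions = 4, user = "spark")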
diff --git a/R/pkg/R/WindowSpec.R b/R/pkg/R/WindowSpec.R
index 4746380096245..07618bf84ca5d 100644
--- a/R/pkg/R/WindowSpec.R
+++ b/R/pkg/R/WindowSpec.R
@@ -54,8 +54,10 @@ setMethod("show", "WindowSpec",
#'
#' Defines the partitioning columns in a WindowSpec.
#'
-#' @param x a WindowSpec
-#' @return a WindowSpec
+#' @param x a WindowSpec.
+#' @param col a column to partition on (described by the name or Column object).
+#' @param ... additional column(s) to partition on.
+#' @return A WindowSpec.
#' @rdname partitionBy
#' @name partitionBy
#' @aliases partitionBy,WindowSpec-method
@@ -82,16 +84,18 @@ setMethod("partitionBy",
}
})
-#' orderBy
+#' Ordering Columns in a WindowSpec
#'
#' Defines the ordering columns in a WindowSpec.
-#'
#' @param x a WindowSpec
-#' @return a WindowSpec
-#' @rdname arrange
+#' @param col a character or Column object indicating an ordering column
+#' @param ... additional sorting fields
+#' @return A WindowSpec.
#' @name orderBy
+#' @rdname orderBy
#' @aliases orderBy,WindowSpec,character-method
#' @family windowspec_method
+#' @seealso See \link{arrange} for use in sorting a SparkDataFrame
#' @export
#' @examples
#' \dontrun{
@@ -105,7 +109,7 @@ setMethod("orderBy",
windowSpec(callJMethod(x@sws, "orderBy", col, list(...)))
})
-#' @rdname arrange
+#' @rdname orderBy
#' @name orderBy
#' @aliases orderBy,WindowSpec,Column-method
#' @export
@@ -122,7 +126,7 @@ setMethod("orderBy",
#' rowsBetween
#'
#' Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive).
-#'
+#'
#' Both `start` and `end` are relative positions from the current row. For example, "0" means
#' "current row", while "-1" means the row before the current row, and "5" means the fifth row
#' after the current row.
@@ -154,7 +158,7 @@ setMethod("rowsBetween",
#' rangeBetween
#'
#' Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive).
-#'
+#'
#' Both `start` and `end` are relative from the current row. For example, "0" means "current row",
#' while "-1" means one off before the current row, and "5" means the five off after the
#' current row.
@@ -188,8 +192,11 @@ setMethod("rangeBetween",
#' over
#'
-#' Define a windowing column.
+#' Define a windowing column.
#'
+#' @param x a Column object, usually one returned by window function(s).
+#' @param window a WindowSpec object. Can be created by `windowPartitionBy` or
+#' `windowOrderBy` and configured by other WindowSpec methods.
#' @rdname over
#' @name over
#' @aliases over,Column,WindowSpec-method
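A sketch connecting the WindowSpec pieces documented in this file (column names are hypothetical):

ws <- orderBy(windowPartitionBy("dept"), "salary")          # partition, then order within each partition
ranked <- withColumn(df, "rank_in_dept", over(rank(), ws))  # windowed column via over()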
diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R
index 0edb9d2ae5c45..8329c0713d21e 100644
--- a/R/pkg/R/column.R
+++ b/R/pkg/R/column.R
@@ -163,8 +163,9 @@ setMethod("alias",
#' @family colum_func
#' @aliases substr,Column-method
#'
-#' @param start starting position
-#' @param stop ending position
+#' @param x a Column object.
+#' @param start starting position.
+#' @param stop ending position.
#' @note substr since 1.4.0
setMethod("substr", signature(x = "Column"),
function(x, start, stop) {
@@ -219,6 +220,7 @@ setMethod("endsWith", signature(x = "Column"),
#' @family colum_func
#' @aliases between,Column-method
#'
+#' @param x a Column object
#' @param bounds lower and upper bounds
#' @note between since 1.5.0
setMethod("between", signature(x = "Column"),
@@ -233,6 +235,11 @@ setMethod("between", signature(x = "Column"),
#' Casts the column to a different data type.
#'
+#' @param x a Column object.
+#' @param dataType a character object describing the target data type.
+#' See
+#' \href{https://spark.apache.org/docs/latest/sparkr.html#data-type-mapping-between-r-and-spark}{
+#' Spark Data Types} for available data types.
#' @rdname cast
#' @name cast
#' @family colum_func
@@ -254,10 +261,12 @@ setMethod("cast",
#' Match a column with given values.
#'
+#' @param x a Column.
+#' @param table a collection of values (coercible to list) to compare with.
#' @rdname match
#' @name %in%
#' @aliases %in%,Column-method
-#' @return a matched values as a result of comparing with given values.
+#' @return A Column of booleans indicating whether each element of \code{x} matches a value in \code{table}.
#' @export
#' @examples
#' \dontrun{
@@ -277,6 +286,9 @@ setMethod("%in%",
#' If values in the specified column are null, returns the value.
#' Can be used in conjunction with `when` to specify a default value for expressions.
#'
+#' @param x a Column.
+#' @param value value to replace when the corresponding entry in \code{x} is NA.
+#' Can be a single value or a Column.
#' @rdname otherwise
#' @name otherwise
#' @family colum_func
diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R
index 2538bb25073e1..13ade49eabfa6 100644
--- a/R/pkg/R/context.R
+++ b/R/pkg/R/context.R
@@ -267,7 +267,7 @@ spark.lapply <- function(list, func) {
sc <- getSparkContext()
rdd <- parallelize(sc, list, length(list))
results <- map(rdd, func)
- local <- collect(results)
+ local <- collectRDD(results)
local
}
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index 573c915a5c67a..1f9115f278071 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -23,6 +23,7 @@ NULL
#' A new \linkS4class{Column} is created to represent the literal value.
#' If the parameter is a \linkS4class{Column}, it is returned unchanged.
#'
+#' @param x a literal value or a Column.
#' @family normal_funcs
#' @rdname lit
#' @name lit
@@ -89,8 +90,6 @@ setMethod("acos",
#' Returns the approximate number of distinct items in a group. This is a column
#' aggregate function.
#'
-#' @param x Column to compute on.
-#'
#' @rdname approxCountDistinct
#' @name approxCountDistinct
#' @return the approximate number of distinct items in a group.
@@ -171,8 +170,6 @@ setMethod("atan",
#'
#' Aggregate function: returns the average of the values in a group.
#'
-#' @param x Column to compute on.
-#'
#' @rdname avg
#' @name avg
#' @family agg_funcs
@@ -319,7 +316,7 @@ setMethod("column",
#'
#' Computes the Pearson Correlation Coefficient for two Columns.
#'
-#' @param x Column to compute on.
+#' @param col2 a (second) Column object.
#'
#' @rdname corr
#' @name corr
@@ -339,8 +336,6 @@ setMethod("corr", signature(x = "Column"),
#'
#' Compute the sample covariance between two expressions.
#'
-#' @param x Column to compute on.
-#'
#' @rdname cov
#' @name cov
#' @family math_funcs
@@ -362,8 +357,8 @@ setMethod("cov", signature(x = "characterOrColumn"),
#' @rdname cov
#'
-#' @param col1 First column to compute cov_samp.
-#' @param col2 Second column to compute cov_samp.
+#' @param col1 the first Column object.
+#' @param col2 the second Column object.
#' @name covar_samp
#' @aliases covar_samp,characterOrColumn,characterOrColumn-method
#' @note covar_samp since 2.0.0
@@ -451,10 +446,8 @@ setMethod("cosh",
#'
#' Returns the number of items in a group. This is a column aggregate function.
#'
-#' @param x Column to compute on.
-#'
-#' @rdname nrow
-#' @name count
+#' @rdname n
+#' @name n
#' @family agg_funcs
#' @aliases count,Column-method
#' @export
@@ -493,6 +486,7 @@ setMethod("crc32",
#' Calculates the hash code of given columns, and returns the result as a int column.
#'
#' @param x Column to compute on.
+#' @param ... additional Column(s) to be included.
#'
#' @rdname hash
#' @name hash
@@ -663,7 +657,8 @@ setMethod("factorial",
#' The function by default returns the first values it sees. It will return the first non-missing
#' value it sees when na.rm is set to true. If all values are missing, then NA is returned.
#'
-#' @param x Column to compute on.
+#' @param na.rm a logical value indicating whether NA values should be stripped
+#' before the computation proceeds.
#'
#' @rdname first
#' @name first
@@ -832,7 +827,10 @@ setMethod("kurtosis",
#' The function by default returns the last values it sees. It will return the last non-missing
#' value it sees when na.rm is set to true. If all values are missing, then NA is returned.
#'
-#' @param x Column to compute on.
+#' @param x column to compute on.
+#' @param na.rm a logical value indicating whether NA values should be stripped
+#' before the computation proceeds.
+#' @param ... further arguments to be passed to or from other methods.
#'
#' @rdname last
#' @name last
@@ -847,7 +845,7 @@ setMethod("kurtosis",
#' @note last since 1.4.0
setMethod("last",
signature(x = "characterOrColumn"),
- function(x, na.rm = FALSE) {
+ function(x, na.rm = FALSE, ...) {
col <- if (class(x) == "Column") {
x@jc
} else {
@@ -1143,7 +1141,7 @@ setMethod("minute",
#' @export
#' @examples \dontrun{select(df, monotonically_increasing_id())}
setMethod("monotonically_increasing_id",
- signature(x = "missing"),
+ signature(),
function() {
jc <- callJStatic("org.apache.spark.sql.functions", "monotonically_increasing_id")
column(jc)
@@ -1273,12 +1271,15 @@ setMethod("round",
#' bround
#'
#' Returns the value of the column `e` rounded to `scale` decimal places using HALF_EVEN rounding
-#' mode if `scale` >= 0 or at integral part when `scale` < 0.
+#' mode if `scale` >= 0 or at integer part when `scale` < 0.
#' Also known as Gaussian rounding or bankers' rounding that rounds to the nearest even number.
#' bround(2.5, 0) = 2, bround(3.5, 0) = 4.
#'
#' @param x Column to compute on.
-#'
+#' @param scale round to \code{scale} digits to the right of the decimal point when \code{scale} > 0,
+#' the nearest even number when \code{scale} = 0, and \code{scale} digits to the left
+#' of the decimal point when \code{scale} < 0.
+#' @param ... further arguments to be passed to or from other methods.
#' @rdname bround
#' @name bround
#' @family math_funcs
@@ -1288,7 +1289,7 @@ setMethod("round",
#' @note bround since 2.0.0
setMethod("bround",
signature(x = "Column"),
- function(x, scale = 0) {
+ function(x, scale = 0, ...) {
jc <- callJStatic("org.apache.spark.sql.functions", "bround", x@jc, as.integer(scale))
column(jc)
})
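A small sketch of the scale argument described above, assuming an active session:

df <- createDataFrame(data.frame(v = c(2.5, 3.5, 1234.5678)))
head(select(df, bround(df$v, 0), bround(df$v, 2), bround(df$v, -2)))
# per the HALF_EVEN rule above, bround(2.5, 0) == 2 and bround(3.5, 0) == 4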
@@ -1319,7 +1320,7 @@ setMethod("rtrim",
#' Aggregate function: alias for \link{stddev_samp}
#'
#' @param x Column to compute on.
-#'
+#' @param na.rm currently not used.
#' @rdname sd
#' @name sd
#' @family agg_funcs
@@ -1335,7 +1336,7 @@ setMethod("rtrim",
#' @note sd since 1.6.0
setMethod("sd",
signature(x = "Column"),
- function(x) {
+ function(x, na.rm) {
# In R, sample standard deviation is calculated with the sd() function.
stddev_samp(x)
})
@@ -1497,7 +1498,7 @@ setMethod("soundex",
#' \dontrun{select(df, spark_partition_id())}
#' @note spark_partition_id since 2.0.0
setMethod("spark_partition_id",
- signature(x = "missing"),
+ signature(),
function() {
jc <- callJStatic("org.apache.spark.sql.functions", "spark_partition_id")
column(jc)
@@ -1560,7 +1561,8 @@ setMethod("stddev_samp",
#'
#' Creates a new struct column that composes multiple input columns.
#'
-#' @param x Column to compute on.
+#' @param x a column to compute on.
+#' @param ... optional column(s) to be included.
#'
#' @rdname struct
#' @name struct
@@ -1831,8 +1833,8 @@ setMethod("upper",
#'
#' Aggregate function: alias for \link{var_samp}.
#'
-#' @param x Column to compute on.
-#'
+#' @param x a Column to compute on.
+#' @param y,na.rm,use currently not used.
#' @rdname var
#' @name var
#' @family agg_funcs
@@ -1848,7 +1850,7 @@ setMethod("upper",
#' @note var since 1.6.0
setMethod("var",
signature(x = "Column"),
- function(x) {
+ function(x, y, na.rm, use) {
# In R, sample variance is calculated with the var() function.
var_samp(x)
})
@@ -2114,7 +2116,9 @@ setMethod("pmod", signature(y = "Column"),
#' @rdname approxCountDistinct
#' @name approxCountDistinct
#'
+#' @param x Column to compute on.
#' @param rsd maximum estimation error allowed (default = 0.05)
+#' @param ... further arguments to be passed to or from other methods.
#'
#' @aliases approxCountDistinct,Column-method
#' @export
@@ -2122,12 +2126,12 @@ setMethod("pmod", signature(y = "Column"),
#' @note approxCountDistinct(Column, numeric) since 1.4.0
setMethod("approxCountDistinct",
signature(x = "Column"),
- function(x, rsd = 0.05) {
+ function(x, rsd = 0.05, ...) {
jc <- callJStatic("org.apache.spark.sql.functions", "approxCountDistinct", x@jc, rsd)
column(jc)
})
-#' Count Distinct
+#' Count Distinct Values
#'
#' @param x Column to compute on
#' @param ... other columns
@@ -2156,7 +2160,7 @@ setMethod("countDistinct",
#' concat
#'
#' Concatenates multiple input string columns together into a single string column.
-#'
+#'
#' @param x Column to compute on
#' @param ... other columns
#'
@@ -2246,7 +2250,6 @@ setMethod("ceiling",
})
#' @rdname sign
-#' @param x Column to compute on
#'
#' @name sign
#' @aliases sign,Column-method
@@ -2262,9 +2265,6 @@ setMethod("sign", signature(x = "Column"),
#'
#' Aggregate function: returns the number of distinct items in a group.
#'
-#' @param x Column to compute on
-#' @param ... other columns
-#'
#' @rdname countDistinct
#' @name n_distinct
#' @aliases n_distinct,Column-method
@@ -2276,9 +2276,8 @@ setMethod("n_distinct", signature(x = "Column"),
countDistinct(x, ...)
})
-#' @rdname nrow
-#' @param x Column to compute on
-#'
+#' @param x a Column.
+#' @rdname n
#' @name n
#' @aliases n,Column-method
#' @export
@@ -2300,8 +2299,8 @@ setMethod("n", signature(x = "Column"),
#' NOTE: Use when ever possible specialized functions like \code{year}. These benefit from a
#' specialized implementation.
#'
-#' @param y Column to compute on
-#' @param x date format specification
+#' @param y Column to compute on.
+#' @param x date format specification.
#'
#' @family datetime_funcs
#' @rdname date_format
@@ -2320,8 +2319,8 @@ setMethod("date_format", signature(y = "Column", x = "character"),
#'
#' Assumes given timestamp is UTC and converts to given timezone.
#'
-#' @param y Column to compute on
-#' @param x time zone to use
+#' @param y Column to compute on.
+#' @param x time zone to use.
#'
#' @family datetime_funcs
#' @rdname from_utc_timestamp
@@ -2370,8 +2369,8 @@ setMethod("instr", signature(y = "Column", x = "character"),
#' Day of the week parameter is case insensitive, and accepts first three or two characters:
#' "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun".
#'
-#' @param y Column to compute on
-#' @param x Day of the week string
+#' @param y Column to compute on.
+#' @param x Day of the week string.
#'
#' @family datetime_funcs
#' @rdname next_day
@@ -2637,6 +2636,7 @@ setMethod("conv", signature(x = "Column", fromBase = "numeric", toBase = "numeri
#' Parses the expression string into the column that it represents, similar to
#' SparkDataFrame.selectExpr
#'
+#' @param x an expression character object to be parsed.
#' @family normal_funcs
#' @rdname expr
#' @aliases expr,character-method
@@ -2654,6 +2654,9 @@ setMethod("expr", signature(x = "character"),
#'
#' Formats the arguments in printf-style and returns the result as a string column.
#'
+#' @param format a character object of format strings.
+#' @param x a Column object.
+#' @param ... additional Column(s).
#' @family string_funcs
#' @rdname format_string
#' @name format_string
@@ -2676,6 +2679,11 @@ setMethod("format_string", signature(format = "character", x = "Column"),
#' representing the timestamp of that moment in the current system time zone in the given
#' format.
#'
+#' @param x a Column of unix timestamp.
+#' @param format the target format. See
+#' \href{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}{
+#' Customizing Formats} for available options.
+#' @param ... further arguments to be passed to or from other methods.
#' @family datetime_funcs
#' @rdname from_unixtime
#' @name from_unixtime
@@ -2688,7 +2696,7 @@ setMethod("format_string", signature(format = "character", x = "Column"),
#'}
#' @note from_unixtime since 1.5.0
setMethod("from_unixtime", signature(x = "Column"),
- function(x, format = "yyyy-MM-dd HH:mm:ss") {
+ function(x, format = "yyyy-MM-dd HH:mm:ss", ...) {
jc <- callJStatic("org.apache.spark.sql.functions",
"from_unixtime",
x@jc, format)
@@ -2702,19 +2710,21 @@ setMethod("from_unixtime", signature(x = "Column"),
#' [12:05,12:10) but not in [12:00,12:05). Windows can support microsecond precision. Windows in
#' the order of months are not supported.
#'
-#' The time column must be of TimestampType.
-#'
-#' Durations are provided as strings, e.g. '1 second', '1 day 12 hours', '2 minutes'. Valid
-#' interval strings are 'week', 'day', 'hour', 'minute', 'second', 'millisecond', 'microsecond'.
-#' If the `slideDuration` is not provided, the windows will be tumbling windows.
-#'
-#' The startTime is the offset with respect to 1970-01-01 00:00:00 UTC with which to start
-#' window intervals. For example, in order to have hourly tumbling windows that start 15 minutes
-#' past the hour, e.g. 12:15-13:15, 13:15-14:15... provide `startTime` as `15 minutes`.
-#'
-#' The output column will be a struct called 'window' by default with the nested columns 'start'
-#' and 'end'.
-#'
+#' @param x a time Column. Must be of TimestampType.
+#' @param windowDuration a string specifying the width of the window, e.g. '1 second',
+#' '1 day 12 hours', '2 minutes'. Valid interval strings are 'week',
+#' 'day', 'hour', 'minute', 'second', 'millisecond', 'microsecond'.
+#' @param slideDuration a string specifying the sliding interval of the window. Same format as
+#' \code{windowDuration}. A new window will be generated every
+#' \code{slideDuration}. Must be less than or equal to
+#' the \code{windowDuration}.
+#' @param startTime the offset with respect to 1970-01-01 00:00:00 UTC with which to start
+#' window intervals. For example, in order to have hourly tumbling windows
+#' that start 15 minutes past the hour, e.g. 12:15-13:15, 13:15-14:15... provide
+#' \code{startTime} as \code{"15 minutes"}.
+#' @param ... further arguments to be passed to or from other methods.
+#' @return An output column of struct called 'window' by default with the nested columns 'start'
+#' and 'end'.
#' @family datetime_funcs
#' @rdname window
#' @name window
@@ -2735,7 +2745,7 @@ setMethod("from_unixtime", signature(x = "Column"),
#'}
#' @note window since 2.0.0
setMethod("window", signature(x = "Column"),
- function(x, windowDuration, slideDuration = NULL, startTime = NULL) {
+ function(x, windowDuration, slideDuration = NULL, startTime = NULL, ...) {
stopifnot(is.character(windowDuration))
if (!is.null(slideDuration) && !is.null(startTime)) {
stopifnot(is.character(slideDuration) && is.character(startTime))
@@ -2766,6 +2776,10 @@ setMethod("window", signature(x = "Column"),
#' NOTE: The position is not zero based, but 1 based index, returns 0 if substr
#' could not be found in str.
#'
+#' @param substr a character string to be matched.
+#' @param str a Column where matches are sought for each entry.
+#' @param pos start position of search.
+#' @param ... further arguments to be passed to or from other methods.
#' @family string_funcs
#' @rdname locate
#' @aliases locate,character,Column-method
@@ -2774,7 +2788,7 @@ setMethod("window", signature(x = "Column"),
#' @examples \dontrun{locate('b', df$c, 1)}
#' @note locate since 1.5.0
setMethod("locate", signature(substr = "character", str = "Column"),
- function(substr, str, pos = 1) {
+ function(substr, str, pos = 1, ...) {
jc <- callJStatic("org.apache.spark.sql.functions",
"locate",
substr, str@jc, as.integer(pos))
@@ -2785,6 +2799,9 @@ setMethod("locate", signature(substr = "character", str = "Column"),
#'
#' Left-pad the string column with
#'
+#' @param x the string Column to be left-padded.
+#' @param len maximum length of each output result.
+#' @param pad a character string to be padded with.
#' @family string_funcs
#' @rdname lpad
#' @aliases lpad,Column,numeric,character-method
@@ -2804,6 +2821,7 @@ setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"),
#'
#' Generate a random column with i.i.d. samples from U[0.0, 1.0].
#'
+#' @param seed a random seed. Can be missing.
#' @family normal_funcs
#' @rdname rand
#' @name rand
@@ -2832,6 +2850,7 @@ setMethod("rand", signature(seed = "numeric"),
#'
#' Generate a column with i.i.d. samples from the standard normal distribution.
#'
+#' @param seed a random seed. Can be missing.
#' @family normal_funcs
#' @rdname randn
#' @name randn
@@ -2860,6 +2879,9 @@ setMethod("randn", signature(seed = "numeric"),
#'
#' Extract a specific(idx) group identified by a java regex, from the specified string column.
#'
+#' @param x a string Column.
+#' @param pattern a regular expression.
+#' @param idx a group index.
#' @family string_funcs
#' @rdname regexp_extract
#' @name regexp_extract
@@ -2880,6 +2902,9 @@ setMethod("regexp_extract",
#'
#' Replace all substrings of the specified string value that match regexp with rep.
#'
+#' @param x a string Column.
+#' @param pattern a regular expression.
+#' @param replacement a character string that a matched \code{pattern} is replaced with.
#' @family string_funcs
#' @rdname regexp_replace
#' @name regexp_replace
@@ -2900,6 +2925,9 @@ setMethod("regexp_replace",
#'
#' Right-padded with pad to a length of len.
#'
+#' @param x the string Column to be right-padded.
+#' @param len maximum length of each output result.
+#' @param pad a character string to be padded with.
#' @family string_funcs
#' @rdname rpad
#' @name rpad
@@ -2922,6 +2950,11 @@ setMethod("rpad", signature(x = "Column", len = "numeric", pad = "character"),
#' returned. If count is negative, every to the right of the final delimiter (counting from the
#' right) is returned. substring_index performs a case-sensitive match when searching for delim.
#'
+#' @param x a Column.
+#' @param delim a delimiter string.
+#' @param count number of occurrences of \code{delim} before the substring is returned.
+#' A positive number means counting from the left, while negative means
+#' counting from the right.
#' @family string_funcs
#' @rdname substring_index
#' @aliases substring_index,Column,character,numeric-method
@@ -2949,6 +2982,11 @@ setMethod("substring_index",
#' The translate will happen when any character in the string matching with the character
#' in the matchingString.
#'
+#' @param x a string Column.
+#' @param matchingString a source string where each character will be translated.
+#' @param replaceString a target string where each \code{matchingString} character will
+#' be replaced by the character in \code{replaceString}
+#' at the same location, if any.
#' @family string_funcs
#' @rdname translate
#' @name translate
@@ -2997,6 +3035,10 @@ setMethod("unix_timestamp", signature(x = "Column", format = "missing"),
column(jc)
})
+#' @param x a Column of date, in string, date or timestamp type.
+#' @param format the target format. See
+#' \href{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}{
+#' Customizing Formats} for available options.
#' @rdname unix_timestamp
#' @name unix_timestamp
#' @aliases unix_timestamp,Column,character-method
@@ -3012,6 +3054,8 @@ setMethod("unix_timestamp", signature(x = "Column", format = "character"),
#' Evaluates a list of conditions and returns one of multiple possible result expressions.
#' For unmatched expressions null is returned.
#'
+#' @param condition the condition to test on. Must be a Column expression.
+#' @param value result expression.
#' @family normal_funcs
#' @rdname when
#' @name when
@@ -3033,6 +3077,9 @@ setMethod("when", signature(condition = "Column", value = "ANY"),
#' Evaluates a list of conditions and returns \code{yes} if the conditions are satisfied.
#' Otherwise \code{no} is returned for unmatched conditions.
#'
+#' @param test a Column expression that describes the condition.
+#' @param yes return values for \code{TRUE} elements of test.
+#' @param no return values for \code{FALSE} elements of test.
#' @family normal_funcs
#' @rdname ifelse
#' @name ifelse
@@ -3074,10 +3121,14 @@ setMethod("ifelse",
#' @family window_funcs
#' @aliases cume_dist,missing-method
#' @export
-#' @examples \dontrun{cume_dist()}
+#' @examples \dontrun{
+#' df <- createDataFrame(iris)
+#' ws <- orderBy(windowPartitionBy("Species"), "Sepal_Length")
+#' out <- select(df, over(cume_dist(), ws), df$Sepal_Length, df$Species)
+#' }
#' @note cume_dist since 1.6.0
setMethod("cume_dist",
- signature(x = "missing"),
+ signature(),
function() {
jc <- callJStatic("org.apache.spark.sql.functions", "cume_dist")
column(jc)
@@ -3101,7 +3152,7 @@ setMethod("cume_dist",
#' @examples \dontrun{dense_rank()}
#' @note dense_rank since 1.6.0
setMethod("dense_rank",
- signature(x = "missing"),
+ signature(),
function() {
jc <- callJStatic("org.apache.spark.sql.functions", "dense_rank")
column(jc)
@@ -3115,6 +3166,11 @@ setMethod("dense_rank",
#'
#' This is equivalent to the LAG function in SQL.
#'
+#' @param x the column as a character string or a Column to compute on.
+#' @param offset the number of rows back from the current row from which to obtain a value.
+#' If not specified, the default is 1.
+#' @param defaultValue default to use when the offset row does not exist.
+#' @param ... further arguments to be passed to or from other methods.
#' @rdname lag
#' @name lag
#' @aliases lag,characterOrColumn-method
@@ -3124,7 +3180,7 @@ setMethod("dense_rank",
#' @note lag since 1.6.0
setMethod("lag",
signature(x = "characterOrColumn"),
- function(x, offset, defaultValue = NULL) {
+ function(x, offset, defaultValue = NULL, ...) {
col <- if (class(x) == "Column") {
x@jc
} else {
@@ -3143,7 +3199,7 @@ setMethod("lag",
#' an `offset` of one will return the next row at any given point in the window partition.
#'
#' This is equivalent to the LEAD function in SQL.
-#'
+#'
#' @param x Column to compute on
#' @param offset Number of rows to offset
#' @param defaultValue (Optional) default value to use
@@ -3211,7 +3267,7 @@ setMethod("ntile",
#' @examples \dontrun{percent_rank()}
#' @note percent_rank since 1.6.0
setMethod("percent_rank",
- signature(x = "missing"),
+ signature(),
function() {
jc <- callJStatic("org.apache.spark.sql.functions", "percent_rank")
column(jc)
@@ -3243,6 +3299,8 @@ setMethod("rank",
})
# Expose rank() in the R base package
+#' @param x a numeric, complex, character or logical vector.
+#' @param ... additional argument(s) passed to the method.
#' @name rank
#' @rdname rank
#' @aliases rank,ANY-method
@@ -3267,7 +3325,7 @@ setMethod("rank",
#' @examples \dontrun{row_number()}
#' @note row_number since 1.6.0
setMethod("row_number",
- signature(x = "missing"),
+ signature(),
function() {
jc <- callJStatic("org.apache.spark.sql.functions", "row_number")
column(jc)
@@ -3318,7 +3376,7 @@ setMethod("explode",
#' size
#'
#' Returns length of array or map.
-#'
+#'
#' @param x Column to compute on
#'
#' @rdname size
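For reference, a minimal usage sketch of the datetime and string helpers documented in the hunks above (a sketch only, assuming an active SparkR session; the column names "ts" and "name" are illustrative and not part of this patch):

# Sketch only: df is assumed to be a SparkDataFrame with a timestamp column `ts`
# and a string column `name` (hypothetical names).
w <- window(df$ts, "10 minutes", startTime = "5 minutes")   # tumbling windows offset by 5 minutes
counts <- agg(groupBy(df, w), total = count(df$name))       # aggregate per time window
formatted <- select(df,
                    from_unixtime(cast(df$ts, "long"), "yyyy-MM-dd"),  # format a unix timestamp
                    locate("a", df$name, pos = 1))                     # 1-based substring search
head(counts)
head(formatted)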
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index e7444ac2467d8..aa3e5deed55b1 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -23,9 +23,7 @@
setGeneric("aggregateRDD",
function(x, zeroValue, seqOp, combOp) { standardGeneric("aggregateRDD") })
-# @rdname cache-methods
-# @export
-setGeneric("cache", function(x) { standardGeneric("cache") })
+setGeneric("cacheRDD", function(x) { standardGeneric("cacheRDD") })
# @rdname coalesce
# @seealso repartition
@@ -36,9 +34,7 @@ setGeneric("coalesce", function(x, numPartitions, ...) { standardGeneric("coales
# @export
setGeneric("checkpoint", function(x) { standardGeneric("checkpoint") })
-# @rdname collect-methods
-# @export
-setGeneric("collect", function(x, ...) { standardGeneric("collect") })
+setGeneric("collectRDD", function(x, ...) { standardGeneric("collectRDD") })
# @rdname collect-methods
# @export
@@ -51,9 +47,9 @@ setGeneric("collectPartition",
standardGeneric("collectPartition")
})
-# @rdname nrow
-# @export
-setGeneric("count", function(x) { standardGeneric("count") })
+setGeneric("countRDD", function(x) { standardGeneric("countRDD") })
+
+setGeneric("lengthRDD", function(x) { standardGeneric("lengthRDD") })
# @rdname countByValue
# @export
@@ -74,17 +76,13 @@ setGeneric("approxQuantile",
standardGeneric("approxQuantile")
})
-# @rdname distinct
-# @export
-setGeneric("distinct", function(x, numPartitions = 1) { standardGeneric("distinct") })
+setGeneric("distinctRDD", function(x, numPartitions = 1) { standardGeneric("distinctRDD") })
# @rdname filterRDD
# @export
setGeneric("filterRDD", function(x, f) { standardGeneric("filterRDD") })
-# @rdname first
-# @export
-setGeneric("first", function(x, ...) { standardGeneric("first") })
+setGeneric("firstRDD", function(x, ...) { standardGeneric("firstRDD") })
# @rdname flatMap
# @export
@@ -110,6 +115,8 @@ setGeneric("glom", function(x) { standardGeneric("glom") })
# @export
setGeneric("histogram", function(df, col, nbins=10) { standardGeneric("histogram") })
+setGeneric("joinRDD", function(x, y, ...) { standardGeneric("joinRDD") })
+
# @rdname keyBy
# @export
setGeneric("keyBy", function(x, func) { standardGeneric("keyBy") })
@@ -152,9 +159,7 @@ setGeneric("getNumPartitions", function(x) { standardGeneric("getNumPartitions")
# @export
setGeneric("numPartitions", function(x) { standardGeneric("numPartitions") })
-# @rdname persist
-# @export
-setGeneric("persist", function(x, newLevel) { standardGeneric("persist") })
+setGeneric("persistRDD", function(x, newLevel) { standardGeneric("persistRDD") })
# @rdname pipeRDD
# @export
@@ -168,10 +173,7 @@ setGeneric("pivot", function(x, colname, values = list()) { standardGeneric("piv
# @export
setGeneric("reduce", function(x, func) { standardGeneric("reduce") })
-# @rdname repartition
-# @seealso coalesce
-# @export
-setGeneric("repartition", function(x, ...) { standardGeneric("repartition") })
+setGeneric("repartitionRDD", function(x, ...) { standardGeneric("repartitionRDD") })
# @rdname sampleRDD
# @export
@@ -193,6 +195,8 @@ setGeneric("saveAsTextFile", function(x, path) { standardGeneric("saveAsTextFile
# @export
setGeneric("setName", function(x, name) { standardGeneric("setName") })
+setGeneric("showRDD", function(object, ...) { standardGeneric("showRDD") })
+
# @rdname sortBy
# @export
setGeneric("sortBy",
@@ -200,9 +204,7 @@ setGeneric("sortBy",
standardGeneric("sortBy")
})
-# @rdname take
-# @export
-setGeneric("take", function(x, num) { standardGeneric("take") })
+setGeneric("takeRDD", function(x, num) { standardGeneric("takeRDD") })
# @rdname takeOrdered
# @export
@@ -223,9 +225,7 @@ setGeneric("top", function(x, num) { standardGeneric("top") })
# @export
setGeneric("unionRDD", function(x, y) { standardGeneric("unionRDD") })
-# @rdname unpersist-methods
-# @export
-setGeneric("unpersist", function(x, ...) { standardGeneric("unpersist") })
+setGeneric("unpersistRDD", function(x, ...) { standardGeneric("unpersistRDD") })
# @rdname zipRDD
# @export
@@ -343,9 +343,7 @@ setGeneric("join", function(x, y, ...) { standardGeneric("join") })
# @export
setGeneric("leftOuterJoin", function(x, y, numPartitions) { standardGeneric("leftOuterJoin") })
-#' @rdname partitionBy
-#' @export
-setGeneric("partitionBy", function(x, ...) { standardGeneric("partitionBy") })
+setGeneric("partitionByRDD", function(x, ...) { standardGeneric("partitionByRDD") })
# @rdname reduceByKey
# @seealso groupByKey
@@ -395,6 +393,9 @@ setGeneric("value", function(bcast) { standardGeneric("value") })
#################### SparkDataFrame Methods ########################
+#' @param x a SparkDataFrame or GroupedData.
+#' @param ... further arguments to be passed to or from other methods.
+#' @return A SparkDataFrame.
#' @rdname summarize
#' @export
setGeneric("agg", function (x, ...) { standardGeneric("agg") })
@@ -414,6 +415,16 @@ setGeneric("as.data.frame",
#' @export
setGeneric("attach")
+#' @rdname cache
+#' @export
+setGeneric("cache", function(x) { standardGeneric("cache") })
+
+#' @rdname collect
+#' @export
+setGeneric("collect", function(x, ...) { standardGeneric("collect") })
+
+#' @param do.NULL currently not used.
+#' @param prefix currently not used.
#' @rdname columns
#' @export
setGeneric("colnames", function(x, do.NULL = TRUE, prefix = "col") { standardGeneric("colnames") })
@@ -434,11 +445,23 @@ setGeneric("coltypes<-", function(x, value) { standardGeneric("coltypes<-") })
#' @export
setGeneric("columns", function(x) {standardGeneric("columns") })
+#' @rdname nrow
+#' @export
+setGeneric("count", function(x) { standardGeneric("count") })
+
#' @rdname cov
+#' @param x a Column object or a SparkDataFrame.
+#' @param ... additional argument(s). If \code{x} is a Column object, a Column object
+#'            should be provided. If \code{x} is a SparkDataFrame, two column names should
+#'            be provided.
#' @export
setGeneric("cov", function(x, ...) {standardGeneric("cov") })
#' @rdname corr
+#' @param x a Column object or a SparkDataFrame.
+#' @param ... additional argument(s). If \code{x} is a Column object, a Column object
+#'            should be provided. If \code{x} is a SparkDataFrame, two column names should
+#'            be provided.
#' @export
setGeneric("corr", function(x, ...) {standardGeneric("corr") })
@@ -465,10 +488,14 @@ setGeneric("dapply", function(x, func, schema) { standardGeneric("dapply") })
#' @export
setGeneric("dapplyCollect", function(x, func) { standardGeneric("dapplyCollect") })
+#' @param x a SparkDataFrame or GroupedData.
+#' @param ... additional argument(s) passed to the method.
#' @rdname gapply
#' @export
setGeneric("gapply", function(x, ...) { standardGeneric("gapply") })
+#' @param x a SparkDataFrame or GroupedData.
+#' @param ... additional argument(s) passed to the method.
#' @rdname gapplyCollect
#' @export
setGeneric("gapplyCollect", function(x, ...) { standardGeneric("gapplyCollect") })
@@ -477,6 +504,10 @@ setGeneric("gapplyCollect", function(x, ...) { standardGeneric("gapplyCollect")
#' @export
setGeneric("describe", function(x, col, ...) { standardGeneric("describe") })
+#' @rdname distinct
+#' @export
+setGeneric("distinct", function(x) { standardGeneric("distinct") })
+
#' @rdname drop
#' @export
setGeneric("drop", function(x, ...) { standardGeneric("drop") })
@@ -519,6 +550,10 @@ setGeneric("fillna", function(x, value, cols = NULL) { standardGeneric("fillna")
#' @export
setGeneric("filter", function(x, condition) { standardGeneric("filter") })
+#' @rdname first
+#' @export
+setGeneric("first", function(x, ...) { standardGeneric("first") })
+
#' @rdname groupBy
#' @export
setGeneric("group_by", function(x, ...) { standardGeneric("group_by") })
@@ -551,21 +586,29 @@ setGeneric("merge")
#' @export
setGeneric("mutate", function(.data, ...) {standardGeneric("mutate") })
-#' @rdname arrange
+#' @rdname orderBy
#' @export
setGeneric("orderBy", function(x, col, ...) { standardGeneric("orderBy") })
+#' @rdname persist
+#' @export
+setGeneric("persist", function(x, newLevel) { standardGeneric("persist") })
+
#' @rdname printSchema
#' @export
setGeneric("printSchema", function(x) { standardGeneric("printSchema") })
+#' @rdname registerTempTable-deprecated
+#' @export
+setGeneric("registerTempTable", function(x, tableName) { standardGeneric("registerTempTable") })
+
#' @rdname rename
#' @export
setGeneric("rename", function(x, ...) { standardGeneric("rename") })
-#' @rdname registerTempTable-deprecated
+#' @rdname repartition
#' @export
-setGeneric("registerTempTable", function(x, tableName) { standardGeneric("registerTempTable") })
+setGeneric("repartition", function(x, ...) { standardGeneric("repartition") })
#' @rdname sample
#' @export
@@ -592,6 +635,10 @@ setGeneric("saveAsTable", function(df, tableName, source = NULL, mode = "error",
#' @export
setGeneric("str")
+#' @rdname take
+#' @export
+setGeneric("take", function(x, num) { standardGeneric("take") })
+
#' @rdname mutate
#' @export
setGeneric("transform", function(`_data`, ...) {standardGeneric("transform") })
@@ -650,8 +697,8 @@ setGeneric("selectExpr", function(x, expr, ...) { standardGeneric("selectExpr")
#' @export
setGeneric("showDF", function(x, ...) { standardGeneric("showDF") })
-# @rdname subset
-# @export
+#' @rdname subset
+#' @export
setGeneric("subset", function(x, ...) { standardGeneric("subset") })
#' @rdname summarize
@@ -674,6 +721,10 @@ setGeneric("union", function(x, y) { standardGeneric("union") })
#' @export
setGeneric("unionAll", function(x, y) { standardGeneric("unionAll") })
+#' @rdname unpersist-methods
+#' @export
+setGeneric("unpersist", function(x, ...) { standardGeneric("unpersist") })
+
#' @rdname filter
#' @export
setGeneric("where", function(x, condition) { standardGeneric("where") })
@@ -714,6 +765,8 @@ setGeneric("between", function(x, bounds) { standardGeneric("between") })
setGeneric("cast", function(x, dataType) { standardGeneric("cast") })
#' @rdname columnfunctions
+#' @param x a Column object.
+#' @param ... additional argument(s).
#' @export
setGeneric("contains", function(x, ...) { standardGeneric("contains") })
@@ -771,6 +824,10 @@ setGeneric("over", function(x, window) { standardGeneric("over") })
###################### WindowSpec Methods ##########################
+#' @rdname partitionBy
+#' @export
+setGeneric("partitionBy", function(x, ...) { standardGeneric("partitionBy") })
+
#' @rdname rowsBetween
#' @export
setGeneric("rowsBetween", function(x, start, end) { standardGeneric("rowsBetween") })
@@ -805,6 +862,8 @@ setGeneric("array_contains", function(x, value) { standardGeneric("array_contain
#' @export
setGeneric("ascii", function(x) { standardGeneric("ascii") })
+#' @param x Column to compute on or a GroupedData object.
+#' @param ... additional argument(s) when \code{x} is a GroupedData object.
#' @rdname avg
#' @export
setGeneric("avg", function(x, ...) { standardGeneric("avg") })
@@ -861,9 +920,10 @@ setGeneric("crc32", function(x) { standardGeneric("crc32") })
#' @export
setGeneric("hash", function(x, ...) { standardGeneric("hash") })
+#' @param ... empty. Use with no argument.
#' @rdname cume_dist
#' @export
-setGeneric("cume_dist", function(x) { standardGeneric("cume_dist") })
+setGeneric("cume_dist", function(...) { standardGeneric("cume_dist") })
#' @rdname datediff
#' @export
@@ -893,9 +953,10 @@ setGeneric("dayofyear", function(x) { standardGeneric("dayofyear") })
#' @export
setGeneric("decode", function(x, charset) { standardGeneric("decode") })
+#' @param ... empty. Use with no argument.
#' @rdname dense_rank
#' @export
-setGeneric("dense_rank", function(x) { standardGeneric("dense_rank") })
+setGeneric("dense_rank", function(...) { standardGeneric("dense_rank") })
#' @rdname encode
#' @export
@@ -1009,10 +1070,11 @@ setGeneric("md5", function(x) { standardGeneric("md5") })
#' @export
setGeneric("minute", function(x) { standardGeneric("minute") })
+#' @param ... empty. Use with no argument.
#' @rdname monotonically_increasing_id
#' @export
setGeneric("monotonically_increasing_id",
- function(x) { standardGeneric("monotonically_increasing_id") })
+ function(...) { standardGeneric("monotonically_increasing_id") })
#' @rdname month
#' @export
@@ -1022,7 +1084,7 @@ setGeneric("month", function(x) { standardGeneric("month") })
#' @export
setGeneric("months_between", function(y, x) { standardGeneric("months_between") })
-#' @rdname nrow
+#' @rdname n
#' @export
setGeneric("n", function(x) { standardGeneric("n") })
@@ -1046,9 +1108,10 @@ setGeneric("ntile", function(x) { standardGeneric("ntile") })
#' @export
setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") })
+#' @param ... empty. Use with no argument.
#' @rdname percent_rank
#' @export
-setGeneric("percent_rank", function(x) { standardGeneric("percent_rank") })
+setGeneric("percent_rank", function(...) { standardGeneric("percent_rank") })
#' @rdname pmod
#' @export
@@ -1089,11 +1152,12 @@ setGeneric("reverse", function(x) { standardGeneric("reverse") })
#' @rdname rint
#' @export
-setGeneric("rint", function(x, ...) { standardGeneric("rint") })
+setGeneric("rint", function(x) { standardGeneric("rint") })
+#' @param ... empty. Use with no argument.
#' @rdname row_number
#' @export
-setGeneric("row_number", function(x) { standardGeneric("row_number") })
+setGeneric("row_number", function(...) { standardGeneric("row_number") })
#' @rdname rpad
#' @export
@@ -1151,9 +1215,10 @@ setGeneric("sort_array", function(x, asc = TRUE) { standardGeneric("sort_array")
#' @export
setGeneric("soundex", function(x) { standardGeneric("soundex") })
+#' @param ... empty. Use with no argument.
#' @rdname spark_partition_id
#' @export
-setGeneric("spark_partition_id", function(x) { standardGeneric("spark_partition_id") })
+setGeneric("spark_partition_id", function(...) { standardGeneric("spark_partition_id") })
#' @rdname sd
#' @export
@@ -1255,6 +1320,8 @@ setGeneric("spark.glm", function(data, formula, ...) { standardGeneric("spark.gl
#' @export
setGeneric("glm")
+#' @param object a fitted ML model object.
+#' @param ... additional argument(s) passed to the method.
#' @rdname predict
#' @export
setGeneric("predict", function(object, ...) { standardGeneric("predict") })
@@ -1277,8 +1344,11 @@ setGeneric("spark.naiveBayes", function(data, formula, ...) { standardGeneric("s
#' @rdname spark.survreg
#' @export
-setGeneric("spark.survreg", function(data, formula, ...) { standardGeneric("spark.survreg") })
+setGeneric("spark.survreg", function(data, formula) { standardGeneric("spark.survreg") })
+#' @param object a fitted ML model object.
+#' @param path the directory where the model is saved.
+#' @param ... additional argument(s) passed to the method.
#' @rdname write.ml
#' @export
setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml") })
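For context, a brief sketch of the split the renames above introduce: RDD-level generics move to internal *RDD names (cacheRDD, collectRDD, countRDD, firstRDD, takeRDD, persistRDD, unpersistRDD, ...), while exported generics with the original names now belong to the SparkDataFrame API. A hedged example, assuming an active SparkR session:

# Sketch only: the user-facing generics keep their familiar names and dispatch on
# SparkDataFrame objects.
df <- createDataFrame(faithful)
count(df)                    # row count (documented under @rdname nrow)
first(df)                    # first row
localDf <- collect(df)       # bring the data back as a local data.frame
persist(df, "MEMORY_ONLY")
unpersist(df)
# The RDD variants (e.g. SparkR:::collectRDD) are internal and not exported.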
diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R
index 85348ae76baa7..72c76c4e61d26 100644
--- a/R/pkg/R/group.R
+++ b/R/pkg/R/group.R
@@ -59,8 +59,8 @@ setMethod("show", "GroupedData",
#' Count the number of rows for each group.
#' The resulting SparkDataFrame will also contain the grouping columns.
#'
-#' @param x a GroupedData
-#' @return a SparkDataFrame
+#' @param x a GroupedData.
+#' @return A SparkDataFrame.
#' @rdname count
#' @aliases count,GroupedData-method
#' @export
@@ -83,8 +83,6 @@ setMethod("count",
#' df2 <- agg(df, <column> = <aggFunction>)
#' df2 <- agg(df, newColName = aggFunction(column))
#'
-#' @param x a GroupedData
-#' @return a SparkDataFrame
#' @rdname summarize
#' @aliases agg,GroupedData-method
#' @name agg
@@ -201,7 +199,6 @@ createMethods()
#' gapply
#'
-#' @param x A GroupedData
#' @rdname gapply
#' @aliases gapply,GroupedData-method
#' @name gapply
@@ -216,7 +213,6 @@ setMethod("gapply",
#' gapplyCollect
#'
-#' @param x A GroupedData
#' @rdname gapplyCollect
#' @aliases gapplyCollect,GroupedData-method
#' @name gapplyCollect
diff --git a/R/pkg/R/install.R b/R/pkg/R/install.R
new file mode 100644
index 0000000000000..987bac7bebc0e
--- /dev/null
+++ b/R/pkg/R/install.R
@@ -0,0 +1,235 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Functions to install Spark in case the user directly downloads SparkR
+# from CRAN.
+
+#' Download and Install Apache Spark to a Local Directory
+#'
+#' \code{install.spark} downloads and installs Spark to a local directory if
+#' it is not found. The Spark version we use is the same as the SparkR version.
+#' Users can specify a desired Hadoop version, the remote mirror site, and
+#' the directory where the package is installed locally.
+#'
+#' The full URL of the remote file is inferred from \code{mirrorUrl} and \code{hadoopVersion}.
+#' \code{mirrorUrl} specifies the remote path to a Spark folder. It is followed by a subfolder
+#' named after the Spark version (that corresponds to SparkR), and then the tar filename.
+#' The filename is composed of four parts, i.e. [Spark version]-bin-[Hadoop version].tgz.
+#' For example, the full path for a Spark 2.0.0 package for Hadoop 2.7 from
+#' \code{http://apache.osuosl.org} has path:
+#' \code{http://apache.osuosl.org/spark/spark-2.0.0/spark-2.0.0-bin-hadoop2.7.tgz}.
+#' For \code{hadoopVersion = "without"}, [Hadoop version] in the filename is then
+#' \code{without-hadoop}.
+#'
+#' @param hadoopVersion Version of Hadoop to install. Default is \code{"2.7"}. It can take other
+#'                      version numbers in the format of "x.y" where x and y are integers.
+#' If \code{hadoopVersion = "without"}, "Hadoop free" build is installed.
+#' See
+#' \href{http://spark.apache.org/docs/latest/hadoop-provided.html}{
+#' "Hadoop Free" Build} for more information.
+#'                      Other patched version names can also be used, e.g. \code{"cdh4"}.
+#' @param mirrorUrl base URL of the repositories to use. The directory layout should follow
+#' \href{http://www.apache.org/dyn/closer.lua/spark/}{Apache mirrors}.
+#' @param localDir a local directory where Spark is installed. The directory contains
+#' version-specific folders of Spark packages. Default is path to
+#' the cache directory:
+#' \itemize{
+#' \item Mac OS X: \file{~/Library/Caches/spark}
+#' \item Unix: \env{$XDG_CACHE_HOME} if defined, otherwise \file{~/.cache/spark}
+#' \item Windows: \file{\%LOCALAPPDATA\%\\spark\\spark\\Cache}. See
+#' \href{https://www.microsoft.com/security/portal/mmpc/shared/variables.aspx}{
+#' Windows Common Folder Variables} about \%LOCALAPPDATA\%
+#' }
+#' @param overwrite If \code{TRUE}, download and overwrite the existing tar file in localDir
+#' and force re-install Spark (in case the local directory or file is corrupted)
+#' @return \code{install.spark} returns the local directory where Spark is found or installed
+#' @rdname install.spark
+#' @name install.spark
+#' @aliases install.spark
+#' @export
+#' @examples
+#'\dontrun{
+#' install.spark()
+#'}
+#' @note install.spark since 2.1.0
+#' @seealso See available Hadoop versions:
+#' \href{http://spark.apache.org/downloads.html}{Apache Spark}
+install.spark <- function(hadoopVersion = "2.7", mirrorUrl = NULL,
+ localDir = NULL, overwrite = FALSE) {
+ version <- paste0("spark-", packageVersion("SparkR"))
+ hadoopVersion <- tolower(hadoopVersion)
+ hadoopVersionName <- hadoop_version_name(hadoopVersion)
+ packageName <- paste(version, "bin", hadoopVersionName, sep = "-")
+ localDir <- ifelse(is.null(localDir), spark_cache_path(),
+ normalizePath(localDir, mustWork = FALSE))
+
+ if (is.na(file.info(localDir)$isdir)) {
+ dir.create(localDir, recursive = TRUE)
+ }
+
+ packageLocalDir <- file.path(localDir, packageName)
+
+ if (overwrite) {
+    message(paste0("Overwrite = TRUE: download and overwrite the tar file ",
+                   "and Spark package directory if they exist."))
+ }
+
+ # can use dir.exists(packageLocalDir) under R 3.2.0 or later
+ if (!is.na(file.info(packageLocalDir)$isdir) && !overwrite) {
+ fmt <- "Spark %s for Hadoop %s is found, and SPARK_HOME set to %s"
+ msg <- sprintf(fmt, version, ifelse(hadoopVersion == "without", "Free build", hadoopVersion),
+ packageLocalDir)
+ message(msg)
+ Sys.setenv(SPARK_HOME = packageLocalDir)
+ return(invisible(packageLocalDir))
+ }
+
+ packageLocalPath <- paste0(packageLocalDir, ".tgz")
+ tarExists <- file.exists(packageLocalPath)
+
+ if (tarExists && !overwrite) {
+ message("tar file found.")
+ } else {
+ robust_download_tar(mirrorUrl, version, hadoopVersion, packageName, packageLocalPath)
+ }
+
+ message(sprintf("Installing to %s", localDir))
+ untar(tarfile = packageLocalPath, exdir = localDir)
+ if (!tarExists || overwrite) {
+ unlink(packageLocalPath)
+ }
+ message("DONE.")
+ Sys.setenv(SPARK_HOME = packageLocalDir)
+ message(paste("SPARK_HOME set to", packageLocalDir))
+ invisible(packageLocalDir)
+}
+
+robust_download_tar <- function(mirrorUrl, version, hadoopVersion, packageName, packageLocalPath) {
+ # step 1: use user-provided url
+ if (!is.null(mirrorUrl)) {
+ msg <- sprintf("Use user-provided mirror site: %s.", mirrorUrl)
+ message(msg)
+ success <- direct_download_tar(mirrorUrl, version, hadoopVersion,
+ packageName, packageLocalPath)
+ if (success) return()
+ } else {
+ message("Mirror site not provided.")
+ }
+
+ # step 2: use url suggested from apache website
+  message("Looking for a mirror site suggested by the Apache website...")
+ mirrorUrl <- get_preferred_mirror(version, packageName)
+ if (!is.null(mirrorUrl)) {
+ success <- direct_download_tar(mirrorUrl, version, hadoopVersion,
+ packageName, packageLocalPath)
+ if (success) return()
+ } else {
+ message("Unable to find suggested mirror site.")
+ }
+
+ # step 3: use backup option
+  message("Falling back to the backup mirror site...")
+ mirrorUrl <- default_mirror_url()
+ success <- direct_download_tar(mirrorUrl, version, hadoopVersion,
+ packageName, packageLocalPath)
+ if (success) {
+ return(packageLocalPath)
+ } else {
+ msg <- sprintf(paste("Unable to download Spark %s for Hadoop %s.",
+ "Please check network connection, Hadoop version,",
+ "or provide other mirror sites."),
+ version, ifelse(hadoopVersion == "without", "Free build", hadoopVersion))
+ stop(msg)
+ }
+}
+
+get_preferred_mirror <- function(version, packageName) {
+ jsonUrl <- paste0("http://www.apache.org/dyn/closer.cgi?path=",
+ file.path("spark", version, packageName),
+ ".tgz&as_json=1")
+ textLines <- readLines(jsonUrl, warn = FALSE)
+ rowNum <- grep("\"preferred\"", textLines)
+ linePreferred <- textLines[rowNum]
+ matchInfo <- regexpr("\"[A-Za-z][A-Za-z0-9+-.]*://.+\"", linePreferred)
+ if (matchInfo != -1) {
+ startPos <- matchInfo + 1
+ endPos <- matchInfo + attr(matchInfo, "match.length") - 2
+ mirrorPreferred <- base::substr(linePreferred, startPos, endPos)
+ mirrorPreferred <- paste0(mirrorPreferred, "spark")
+ message(sprintf("Preferred mirror site found: %s", mirrorPreferred))
+ } else {
+ mirrorPreferred <- NULL
+ }
+ mirrorPreferred
+}
+
+direct_download_tar <- function(mirrorUrl, version, hadoopVersion, packageName, packageLocalPath) {
+ packageRemotePath <- paste0(
+ file.path(mirrorUrl, version, packageName), ".tgz")
+ fmt <- paste("Downloading Spark %s for Hadoop %s from:\n- %s")
+ msg <- sprintf(fmt, version, ifelse(hadoopVersion == "without", "Free build", hadoopVersion),
+ packageRemotePath)
+ message(msg)
+
+ isFail <- tryCatch(download.file(packageRemotePath, packageLocalPath),
+ error = function(e) {
+ message(sprintf("Fetch failed from %s", mirrorUrl))
+ print(e)
+ TRUE
+ })
+ !isFail
+}
+
+default_mirror_url <- function() {
+ "http://www-us.apache.org/dist/spark"
+}
+
+hadoop_version_name <- function(hadoopVersion) {
+ if (hadoopVersion == "without") {
+ "without-hadoop"
+ } else if (grepl("^[0-9]+\\.[0-9]+$", hadoopVersion, perl = TRUE)) {
+ paste0("hadoop", hadoopVersion)
+ } else {
+ hadoopVersion
+ }
+}
+
+# The implementation refers to the appdirs package (https://pypi.python.org/pypi/appdirs)
+# and adapts it to the Spark context.
+spark_cache_path <- function() {
+ if (.Platform$OS.type == "windows") {
+    winAppPath <- Sys.getenv("LOCALAPPDATA", unset = NA)
+ if (is.na(winAppPath)) {
+ msg <- paste("%LOCALAPPDATA% not found.",
+ "Please define the environment variable",
+ "or restart and enter an installation path in localDir.")
+ stop(msg)
+ } else {
+ path <- file.path(winAppPath, "spark", "spark", "Cache")
+ }
+ } else if (.Platform$OS.type == "unix") {
+ if (Sys.info()["sysname"] == "Darwin") {
+ path <- file.path(Sys.getenv("HOME"), "Library/Caches", "spark")
+ } else {
+ path <- file.path(
+ Sys.getenv("XDG_CACHE_HOME", file.path(Sys.getenv("HOME"), ".cache")), "spark")
+ }
+ } else {
+ stop(sprintf("Unknown OS: %s", .Platform$OS.type))
+ }
+ normalizePath(path, mustWork = FALSE)
+}
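A hedged usage sketch of the new install.spark() entry point defined above (the mirror URL and local directory below are placeholders, not values mandated by this patch):

# Sketch only: install the default Hadoop 2.7 build into the per-user cache directory.
library(SparkR)
install.spark()

# Or pin a "Hadoop free" build, an explicit mirror and a custom install directory.
install.spark(hadoopVersion = "without",
              mirrorUrl = "http://apache.osuosl.org/spark",
              localDir = "~/spark-install",
              overwrite = TRUE)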
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 50c601fcd9e1b..16b0a5d05680f 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -82,15 +82,16 @@ NULL
#' Users can call \code{summary} to print a summary of the fitted model, \code{predict} to make
#' predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models.
#'
-#' @param data SparkDataFrame for training.
-#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
+#' @param data a SparkDataFrame for training.
+#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
#' operators are supported, including '~', '.', ':', '+', and '-'.
-#' @param family A description of the error distribution and link function to be used in the model.
+#' @param family a description of the error distribution and link function to be used in the model.
#' This can be a character string naming a family function, a family function or
#' the result of a call to a family function. Refer R family at
#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
-#' @param tol Positive convergence tolerance of iterations.
-#' @param maxIter Integer giving the maximal number of IRLS iterations.
+#' @param tol positive convergence tolerance of iterations.
+#' @param maxIter integer giving the maximal number of IRLS iterations.
+#' @param ... additional arguments passed to the method.
#' @aliases spark.glm,SparkDataFrame,formula-method
#' @return \code{spark.glm} returns a fitted generalized linear model
#' @rdname spark.glm
@@ -142,15 +143,6 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
#' Generalized Linear Models (R-compliant)
#'
#' Fits a generalized linear model, similarly to R's glm().
-#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
-#' operators are supported, including '~', '.', ':', '+', and '-'.
-#' @param data SparkDataFrame for training.
-#' @param family A description of the error distribution and link function to be used in the model.
-#' This can be a character string naming a family function, a family function or
-#' the result of a call to a family function. Refer R family at
-#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
-#' @param epsilon Positive convergence tolerance of iterations.
-#' @param maxit Integer giving the maximal number of IRLS iterations.
#' @return \code{glm} returns a fitted generalized linear model.
#' @rdname glm
#' @export
@@ -171,7 +163,7 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDat
# Returns the summary of a model produced by glm() or spark.glm(), similarly to R's summary().
-#' @param object A fitted generalized linear model
+#' @param object a fitted generalized linear model.
#' @return \code{summary} returns a summary object of the fitted model, a list of components
#' including at least the coefficients, null/residual deviance, null/residual degrees
#' of freedom, AIC and number of iterations IRLS takes.
@@ -212,7 +204,7 @@ setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"),
# Prints the summary of GeneralizedLinearRegressionModel
#' @rdname spark.glm
-#' @param x Summary object of fitted generalized linear model returned by \code{summary} function
+#' @param x summary object of fitted generalized linear model returned by \code{summary} function
#' @export
#' @note print.summary.GeneralizedLinearRegressionModel since 2.0.0
print.summary.GeneralizedLinearRegressionModel <- function(x, ...) {
@@ -244,7 +236,7 @@ print.summary.GeneralizedLinearRegressionModel <- function(x, ...) {
# Makes predictions from a generalized linear model produced by glm() or spark.glm(),
# similarly to R's predict().
-#' @param newData SparkDataFrame for testing
+#' @param newData a SparkDataFrame for testing.
#' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named
#' "prediction"
#' @rdname spark.glm
@@ -258,7 +250,7 @@ setMethod("predict", signature(object = "GeneralizedLinearRegressionModel"),
# Makes predictions from a naive Bayes model or a model produced by spark.naiveBayes(),
# similarly to R package e1071's predict.
-#' @param newData A SparkDataFrame for testing
+#' @param newData a SparkDataFrame for testing.
#' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named
#' "prediction"
#' @rdname spark.naiveBayes
@@ -271,9 +263,9 @@ setMethod("predict", signature(object = "NaiveBayesModel"),
# Returns the summary of a naive Bayes model produced by \code{spark.naiveBayes}
-#' @param object A naive Bayes model fitted by \code{spark.naiveBayes}
+#' @param object a naive Bayes model fitted by \code{spark.naiveBayes}.
#' @return \code{summary} returns a list containing \code{apriori}, the label distribution, and
-#' \code{tables}, conditional probabilities given the target label
+#' \code{tables}, conditional probabilities given the target label.
#' @rdname spark.naiveBayes
#' @export
#' @note summary(NaiveBayesModel) since 2.0.0
@@ -298,14 +290,15 @@ setMethod("summary", signature(object = "NaiveBayesModel"),
#' Users can call \code{summary} to print a summary of the fitted model, \code{predict} to make
#' predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models.
#'
-#' @param data SparkDataFrame for training
-#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
+#' @param data a SparkDataFrame for training.
+#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
#' operators are supported, including '~', '.', ':', '+', and '-'.
#' Note that the response variable of formula is empty in spark.kmeans.
-#' @param k Number of centers
-#' @param maxIter Maximum iteration number
-#' @param initMode The initialization algorithm choosen to fit the model
-#' @return \code{spark.kmeans} returns a fitted k-means model
+#' @param ... additional argument(s) passed to the method.
+#' @param k number of centers.
+#' @param maxIter maximum iteration number.
+#' @param initMode the initialization algorithm choosen to fit the model.
+#' @return \code{spark.kmeans} returns a fitted k-means model.
#' @rdname spark.kmeans
#' @aliases spark.kmeans,SparkDataFrame,formula-method
#' @name spark.kmeans
@@ -346,8 +339,11 @@ setMethod("spark.kmeans", signature(data = "SparkDataFrame", formula = "formula"
#' Get fitted result from a k-means model, similarly to R's fitted().
#' Note: A saved-loaded model does not support this method.
#'
-#' @param object A fitted k-means model
-#' @return \code{fitted} returns a SparkDataFrame containing fitted values
+#' @param object a fitted k-means model.
+#' @param method type of fitted results, \code{"centers"} for cluster centers
+#' or \code{"classes"} for assigned classes.
+#' @param ... additional argument(s) passed to the method.
+#' @return \code{fitted} returns a SparkDataFrame containing fitted values.
#' @rdname fitted
#' @export
#' @examples
@@ -371,8 +367,8 @@ setMethod("fitted", signature(object = "KMeansModel"),
# Get the summary of a k-means model
-#' @param object A fitted k-means model
-#' @return \code{summary} returns the model's coefficients, size and cluster
+#' @param object a fitted k-means model.
+#' @return \code{summary} returns the model's coefficients, size and cluster.
#' @rdname spark.kmeans
#' @export
#' @note summary(KMeansModel) since 2.0.0
@@ -398,7 +394,8 @@ setMethod("summary", signature(object = "KMeansModel"),
# Predicted values based on a k-means model
-#' @return \code{predict} returns the predicted values based on a k-means model
+#' @param newData a SparkDataFrame for testing.
+#' @return \code{predict} returns the predicted values based on a k-means model.
#' @rdname spark.kmeans
#' @export
#' @note predict(KMeansModel) since 2.0.0
@@ -414,11 +411,12 @@ setMethod("predict", signature(object = "KMeansModel"),
#' predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models.
#' Only categorical data is supported.
#'
-#' @param data A \code{SparkDataFrame} of observations and labels for model fitting
-#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
+#' @param data a \code{SparkDataFrame} of observations and labels for model fitting.
+#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
#' operators are supported, including '~', '.', ':', '+', and '-'.
-#' @param smoothing Smoothing parameter
-#' @return \code{spark.naiveBayes} returns a fitted naive Bayes model
+#' @param ... additional argument(s) passed to the method. Currently only \code{smoothing}.
+#' @param smoothing smoothing parameter.
+#' @return \code{spark.naiveBayes} returns a fitted naive Bayes model.
#' @rdname spark.naiveBayes
#' @aliases spark.naiveBayes,SparkDataFrame,formula-method
#' @name spark.naiveBayes
@@ -445,7 +443,7 @@ setMethod("predict", signature(object = "KMeansModel"),
#' }
#' @note spark.naiveBayes since 2.0.0
setMethod("spark.naiveBayes", signature(data = "SparkDataFrame", formula = "formula"),
- function(data, formula, smoothing = 1.0, ...) {
+ function(data, formula, smoothing = 1.0) {
formula <- paste(deparse(formula), collapse = "")
jobj <- callJStatic("org.apache.spark.ml.r.NaiveBayesWrapper", "fit",
formula, data@sdf, smoothing)
@@ -454,8 +452,8 @@ setMethod("spark.naiveBayes", signature(data = "SparkDataFrame", formula = "form
# Saves the Bernoulli naive Bayes model to the input path.
-#' @param path The directory where the model is saved
-#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
+#' @param path the directory where the model is saved
+#' @param overwrite overwrites or not if the output path already exists. Default is FALSE
#' which means throw exception if the output path exists.
#'
#' @rdname spark.naiveBayes
@@ -473,10 +471,9 @@ setMethod("write.ml", signature(object = "NaiveBayesModel", path = "character"),
# Saves the AFT survival regression model to the input path.
-#' @param path The directory where the model is saved
-#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
+#' @param path the directory where the model is saved.
+#' @param overwrite overwrites or not if the output path already exists. Default is FALSE
#' which means throw exception if the output path exists.
-#'
#' @rdname spark.survreg
#' @export
#' @note write.ml(AFTSurvivalRegressionModel, character) since 2.0.0
@@ -492,8 +489,8 @@ setMethod("write.ml", signature(object = "AFTSurvivalRegressionModel", path = "c
# Saves the generalized linear model to the input path.
-#' @param path The directory where the model is saved
-#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
+#' @param path the directory where the model is saved.
+#' @param overwrite overwrites or not if the output path already exists. Default is FALSE
#' which means throw exception if the output path exists.
#'
#' @rdname spark.glm
@@ -510,8 +507,8 @@ setMethod("write.ml", signature(object = "GeneralizedLinearRegressionModel", pat
# Save fitted MLlib model to the input path
-#' @param path The directory where the model is saved
-#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
+#' @param path the directory where the model is saved.
+#' @param overwrite overwrites or not if the output path already exists. Default is FALSE
#' which means throw exception if the output path exists.
#'
#' @rdname spark.kmeans
@@ -528,8 +525,8 @@ setMethod("write.ml", signature(object = "KMeansModel", path = "character"),
#' Load a fitted MLlib model from the input path.
#'
-#' @param path Path of the model to read.
-#' @return a fitted MLlib model
+#' @param path path of the model to read.
+#' @return A fitted MLlib model.
#' @rdname read.ml
#' @name read.ml
#' @export
@@ -563,11 +560,11 @@ read.ml <- function(path) {
#' \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to
#' save/load fitted models.
#'
-#' @param data A SparkDataFrame for training
-#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
+#' @param data a SparkDataFrame for training.
+#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
#' operators are supported, including '~', ':', '+', and '-'.
-#' Note that operator '.' is not supported currently
-#' @return \code{spark.survreg} returns a fitted AFT survival regression model
+#' Note that operator '.' is not supported currently.
+#' @return \code{spark.survreg} returns a fitted AFT survival regression model.
#' @rdname spark.survreg
#' @seealso survival: \url{https://cran.r-project.org/web/packages/survival/}
#' @export
@@ -591,7 +588,7 @@ read.ml <- function(path) {
#' }
#' @note spark.survreg since 2.0.0
setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula"),
- function(data, formula, ...) {
+ function(data, formula) {
formula <- paste(deparse(formula), collapse = "")
jobj <- callJStatic("org.apache.spark.ml.r.AFTSurvivalRegressionWrapper",
"fit", formula, data@sdf)
@@ -602,14 +599,14 @@ setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula
# Returns a summary of the AFT survival regression model produced by spark.survreg,
# similarly to R's summary().
-#' @param object A fitted AFT survival regression model
+#' @param object a fitted AFT survival regression model.
#' @return \code{summary} returns a list containing the model's coefficients,
#' intercept and log(scale)
#' @rdname spark.survreg
#' @export
#' @note summary(AFTSurvivalRegressionModel) since 2.0.0
setMethod("summary", signature(object = "AFTSurvivalRegressionModel"),
- function(object, ...) {
+ function(object) {
jobj <- object@jobj
features <- callJMethod(jobj, "rFeatures")
coefficients <- callJMethod(jobj, "rCoefficients")
@@ -622,9 +619,9 @@ setMethod("summary", signature(object = "AFTSurvivalRegressionModel"),
# Makes predictions from an AFT survival regression model or a model produced by
# spark.survreg, similarly to R package survival's predict.
-#' @param newData A SparkDataFrame for testing
+#' @param newData a SparkDataFrame for testing.
#' @return \code{predict} returns a SparkDataFrame containing predicted values
-#' on the original scale of the data (mean predicted value at scale = 1.0)
+#' on the original scale of the data (mean predicted value at scale = 1.0).
#' @rdname spark.survreg
#' @export
#' @note predict(AFTSurvivalRegressionModel) since 2.0.0
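For reference, a minimal sketch exercising the MLlib wrappers whose parameter documentation is revised above (a sketch only, assuming an active SparkR session; the model path is a placeholder):

# Sketch only: fit, summarize, predict and persist a generalized linear model.
training <- createDataFrame(iris)
model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species, family = "gaussian")
summary(model)
preds <- predict(model, training)
head(select(preds, "Sepal_Width", "prediction"))
modelPath <- file.path(tempdir(), "glm-model")   # illustrative path
write.ml(model, modelPath)
reloaded <- read.ml(modelPath)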
diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R
index d39775cabef88..f0605db1e9e83 100644
--- a/R/pkg/R/pairRDD.R
+++ b/R/pkg/R/pairRDD.R
@@ -49,7 +49,7 @@ setMethod("lookup",
lapply(filtered, function(i) { i[[2]] })
}
valsRDD <- lapplyPartition(x, partitionFunc)
- collect(valsRDD)
+ collectRDD(valsRDD)
})
#' Count the number of elements for each key, and return the result to the
@@ -85,7 +85,7 @@ setMethod("countByKey",
#'\dontrun{
#' sc <- sparkR.init()
#' rdd <- parallelize(sc, list(list(1, 2), list(3, 4)))
-#' collect(keys(rdd)) # list(1, 3)
+#' collectRDD(keys(rdd)) # list(1, 3)
#'}
# nolint end
#' @rdname keys
@@ -108,7 +108,7 @@ setMethod("keys",
#'\dontrun{
#' sc <- sparkR.init()
#' rdd <- parallelize(sc, list(list(1, 2), list(3, 4)))
-#' collect(values(rdd)) # list(2, 4)
+#' collectRDD(values(rdd)) # list(2, 4)
#'}
# nolint end
#' @rdname values
@@ -135,7 +135,7 @@ setMethod("values",
#' sc <- sparkR.init()
#' rdd <- parallelize(sc, 1:10)
#' makePairs <- lapply(rdd, function(x) { list(x, x) })
-#' collect(mapValues(makePairs, function(x) { x * 2) })
+#' collectRDD(mapValues(makePairs, function(x) { x * 2) })
#' Output: list(list(1,2), list(2,4), list(3,6), ...)
#'}
#' @rdname mapValues
@@ -162,7 +162,7 @@ setMethod("mapValues",
#'\dontrun{
#' sc <- sparkR.init()
#' rdd <- parallelize(sc, list(list(1, c(1,2)), list(2, c(3,4))))
-#' collect(flatMapValues(rdd, function(x) { x }))
+#' collectRDD(flatMapValues(rdd, function(x) { x }))
#' Output: list(list(1,1), list(1,2), list(2,3), list(2,4))
#'}
#' @rdname flatMapValues
@@ -198,13 +198,13 @@ setMethod("flatMapValues",
#' sc <- sparkR.init()
#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4))
#' rdd <- parallelize(sc, pairs)
-#' parts <- partitionBy(rdd, 2L)
+#' parts <- partitionByRDD(rdd, 2L)
#' collectPartition(parts, 0L) # First partition should contain list(1, 2) and list(1, 4)
#'}
#' @rdname partitionBy
#' @aliases partitionBy,RDD,integer-method
#' @noRd
-setMethod("partitionBy",
+setMethod("partitionByRDD",
signature(x = "RDD"),
function(x, numPartitions, partitionFunc = hashCode) {
stopifnot(is.numeric(numPartitions))
@@ -261,7 +261,7 @@ setMethod("partitionBy",
#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4))
#' rdd <- parallelize(sc, pairs)
#' parts <- groupByKey(rdd, 2L)
-#' grouped <- collect(parts)
+#' grouped <- collectRDD(parts)
#' grouped[[1]] # Should be a list(1, list(2, 4))
#'}
#' @rdname groupByKey
@@ -270,7 +270,7 @@ setMethod("partitionBy",
setMethod("groupByKey",
signature(x = "RDD", numPartitions = "numeric"),
function(x, numPartitions) {
- shuffled <- partitionBy(x, numPartitions)
+ shuffled <- partitionByRDD(x, numPartitions)
groupVals <- function(part) {
vals <- new.env()
keys <- new.env()
@@ -321,7 +321,7 @@ setMethod("groupByKey",
#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4))
#' rdd <- parallelize(sc, pairs)
#' parts <- reduceByKey(rdd, "+", 2L)
-#' reduced <- collect(parts)
+#' reduced <- collectRDD(parts)
#' reduced[[1]] # Should be a list(1, 6)
#'}
#' @rdname reduceByKey
@@ -342,7 +342,7 @@ setMethod("reduceByKey",
convertEnvsToList(keys, vals)
}
locallyReduced <- lapplyPartition(x, reduceVals)
- shuffled <- partitionBy(locallyReduced, numToInt(numPartitions))
+ shuffled <- partitionByRDD(locallyReduced, numToInt(numPartitions))
lapplyPartition(shuffled, reduceVals)
})
@@ -430,7 +430,7 @@ setMethod("reduceByKeyLocally",
#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4))
#' rdd <- parallelize(sc, pairs)
#' parts <- combineByKey(rdd, function(x) { x }, "+", "+", 2L)
-#' combined <- collect(parts)
+#' combined <- collectRDD(parts)
#' combined[[1]] # Should be a list(1, 6)
#'}
# nolint end
@@ -453,7 +453,7 @@ setMethod("combineByKey",
convertEnvsToList(keys, combiners)
}
locallyCombined <- lapplyPartition(x, combineLocally)
- shuffled <- partitionBy(locallyCombined, numToInt(numPartitions))
+ shuffled <- partitionByRDD(locallyCombined, numToInt(numPartitions))
mergeAfterShuffle <- function(part) {
combiners <- new.env()
keys <- new.env()
@@ -563,13 +563,13 @@ setMethod("foldByKey",
#' sc <- sparkR.init()
#' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4)))
#' rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3)))
-#' join(rdd1, rdd2, 2L) # list(list(1, list(1, 2)), list(1, list(1, 3))
+#' joinRDD(rdd1, rdd2, 2L) # list(list(1, list(1, 2)), list(1, list(1, 3))
#'}
# nolint end
#' @rdname join-methods
#' @aliases join,RDD,RDD-method
#' @noRd
-setMethod("join",
+setMethod("joinRDD",
signature(x = "RDD", y = "RDD"),
function(x, y, numPartitions) {
xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) })
@@ -772,7 +772,7 @@ setMethod("cogroup",
#'\dontrun{
#' sc <- sparkR.init()
#' rdd <- parallelize(sc, list(list(3, 1), list(2, 2), list(1, 3)))
-#' collect(sortByKey(rdd)) # list (list(1, 3), list(2, 2), list(3, 1))
+#' collectRDD(sortByKey(rdd)) # list (list(1, 3), list(2, 2), list(3, 1))
#'}
# nolint end
#' @rdname sortByKey
@@ -784,12 +784,12 @@ setMethod("sortByKey",
rangeBounds <- list()
if (numPartitions > 1) {
- rddSize <- count(x)
+ rddSize <- countRDD(x)
# constant from Spark's RangePartitioner
maxSampleSize <- numPartitions * 20
fraction <- min(maxSampleSize / max(rddSize, 1), 1.0)
- samples <- collect(keys(sampleRDD(x, FALSE, fraction, 1L)))
+ samples <- collectRDD(keys(sampleRDD(x, FALSE, fraction, 1L)))
# Note: the built-in R sort() function only works on atomic vectors
samples <- sort(unlist(samples, recursive = FALSE), decreasing = !ascending)
@@ -822,7 +822,7 @@ setMethod("sortByKey",
sortKeyValueList(part, decreasing = !ascending)
}
- newRDD <- partitionBy(x, numPartitions, rangePartitionFunc)
+ newRDD <- partitionByRDD(x, numPartitions, rangePartitionFunc)
lapplyPartition(newRDD, partitionFunc)
})
@@ -841,7 +841,7 @@ setMethod("sortByKey",
#' rdd1 <- parallelize(sc, list(list("a", 1), list("b", 4),
#' list("b", 5), list("a", 2)))
#' rdd2 <- parallelize(sc, list(list("a", 3), list("c", 1)))
-#' collect(subtractByKey(rdd1, rdd2))
+#' collectRDD(subtractByKey(rdd1, rdd2))
#' # list(list("b", 4), list("b", 5))
#'}
# nolint end
diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R
index b429f5de13b87..cb5bdb90175bf 100644
--- a/R/pkg/R/schema.R
+++ b/R/pkg/R/schema.R
@@ -92,8 +92,9 @@ print.structType <- function(x, ...) {
#'
#' Create a structField object that contains the metadata for a single field in a schema.
#'
-#' @param x The name of the field
-#' @return a structField object
+#' @param x the name of the field.
+#' @param ... additional argument(s) passed to the method.
+#' @return A structField object.
#' @rdname structField
#' @export
#' @examples
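
For context on the structField documentation touched above, here is a minimal usage sketch of the schema builders (both exported from SparkR). The field names are made up for illustration, and an active session is assumed because the builders create JVM-side objects.

```r
# Minimal sketch, assuming an active SparkR session (the builders call into the JVM).
library(SparkR)
sparkR.session(master = "local[1]")
schema <- structType(structField("name", "string"),
                     structField("age", "integer", nullable = TRUE))
print(schema)   # uses print.structType from this file
df <- createDataFrame(list(list("Andy", 30L)), schema)
printSchema(df)
```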
diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R
index 524f7c4a26b67..85815af1f3639 100644
--- a/R/pkg/R/sparkR.R
+++ b/R/pkg/R/sparkR.R
@@ -320,14 +320,15 @@ sparkRHive.init <- function(jsc = NULL) {
#' For details on how to initialize and use SparkR, refer to SparkR programming guide at
#' \url{http://spark.apache.org/docs/latest/sparkr.html#starting-up-sparksession}.
#'
-#' @param master The Spark master URL
-#' @param appName Application name to register with cluster manager
-#' @param sparkHome Spark Home directory
-#' @param sparkConfig Named list of Spark configuration to set on worker nodes
-#' @param sparkJars Character vector of jar files to pass to the worker nodes
-#' @param sparkPackages Character vector of packages from spark-packages.org
-#' @param enableHiveSupport Enable support for Hive, fallback if not built with Hive support; once
+#' @param master the Spark master URL.
+#' @param appName application name to register with cluster manager.
+#' @param sparkHome Spark Home directory.
+#' @param sparkConfig named list of Spark configuration to set on worker nodes.
+#' @param sparkJars character vector of jar files to pass to the worker nodes.
+#' @param sparkPackages character vector of packages from spark-packages.org
+#' @param enableHiveSupport enable support for Hive, fallback if not built with Hive support; once
#' set, this cannot be turned off on an existing session
+#' @param ... named Spark properties passed to the method.
#' @export
#' @examples
#'\dontrun{
@@ -365,6 +366,23 @@ sparkR.session <- function(
}
overrideEnvs(sparkConfigMap, paramMap)
}
+ # for local masters, make sure Spark is installed; skip this (and any download) in the sparkR shell
+ if (!nzchar(master) || is_master_local(master)) {
+ if (!is_sparkR_shell()) {
+ if (is.na(file.info(sparkHome)$isdir)) {
+ msg <- paste0("Spark not found in SPARK_HOME: ",
+ sparkHome,
+ " .\nTo search in the cache directory. ",
+ "Installation will start if not found.")
+ message(msg)
+ packageLocalDir <- install.spark()
+ sparkHome <- packageLocalDir
+ } else {
+ msg <- paste0("Spark package is found in SPARK_HOME: ", sparkHome)
+ message(msg)
+ }
+ }
+ }
if (!exists(".sparkRjsc", envir = .sparkREnv)) {
sparkExecutorEnvMap <- new.env()
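
To make the new auto-install path above concrete, here is a hedged sketch of how it would be exercised from a plain R session (not the sparkR shell). Only the behaviour described in the hunk — for a local master without a valid SPARK_HOME, look in the cache and install if needed — is implied by this diff; the path and the install.spark() arguments shown in the comment are assumptions.

```r
# Hedged sketch: triggers the SPARK_HOME check added above. The library path and
# the install.spark() arguments are illustrative assumptions.
library(SparkR, lib.loc = "/path/to/lib")   # hypothetical library location
Sys.unsetenv("SPARK_HOME")                  # simulate a missing SPARK_HOME
sparkR.session(master = "local[*]", appName = "install-demo")
# The cache can also be pre-populated explicitly, e.g.:
# install.spark(hadoopVersion = "2.7", overwrite = FALSE)
```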
@@ -396,9 +414,9 @@ sparkR.session <- function(
#' Assigns a group ID to all the jobs started by this thread until the group ID is set to a
#' different value or cleared.
#'
-#' @param groupid the ID to be assigned to job groups
-#' @param description description for the job group ID
-#' @param interruptOnCancel flag to indicate if the job is interrupted on job cancellation
+#' @param groupId the ID to be assigned to job groups.
+#' @param description description for the job group ID.
+#' @param interruptOnCancel flag to indicate if the job is interrupted on job cancellation.
#' @rdname setJobGroup
#' @name setJobGroup
#' @examples
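
A short usage sketch for the renamed groupId argument, assuming the three-argument form implied by the @param block above (no SparkContext handle) and an active session; the group name and description are made up.

```r
# Assumes an active SparkR session and the signature documented above;
# group ID and description are illustrative.
setJobGroup("nightly-etl", "nightly ETL jobs", interruptOnCancel = TRUE)
# ... jobs submitted here are tagged with the group ID ...
cancelJobGroup("nightly-etl")   # cancel all jobs in the group
clearJobGroup()                 # clear the group for this thread
```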
diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R
index 2b4ce195cbddb..8ea24d81729ec 100644
--- a/R/pkg/R/stats.R
+++ b/R/pkg/R/stats.R
@@ -25,6 +25,7 @@ setOldClass("jobj")
#' table. The number of distinct values for each column should be less than 1e4. At most 1e6
#' non-zero pair frequencies will be returned.
#'
+#' @param x a SparkDataFrame.
#' @param col1 name of the first column. Distinct items will make the first item of each row.
#' @param col2 name of the second column. Distinct items will make the column names of the output.
#' @return a local R data.frame representing the contingency table. The first column of each row
@@ -53,10 +54,9 @@ setMethod("crosstab",
#' Calculate the sample covariance of two numerical columns of a SparkDataFrame.
#'
-#' @param x A SparkDataFrame
-#' @param col1 the name of the first column
-#' @param col2 the name of the second column
-#' @return the covariance of the two columns.
+#' @param colName1 the name of the first column
+#' @param colName2 the name of the second column
+#' @return The covariance of the two columns.
#'
#' @rdname cov
#' @name cov
@@ -71,19 +71,18 @@ setMethod("crosstab",
#' @note cov since 1.6.0
setMethod("cov",
signature(x = "SparkDataFrame"),
- function(x, col1, col2) {
- stopifnot(class(col1) == "character" && class(col2) == "character")
+ function(x, colName1, colName2) {
+ stopifnot(class(colName1) == "character" && class(colName2) == "character")
statFunctions <- callJMethod(x@sdf, "stat")
- callJMethod(statFunctions, "cov", col1, col2)
+ callJMethod(statFunctions, "cov", colName1, colName2)
})
#' Calculates the correlation of two columns of a SparkDataFrame.
#' Currently only supports the Pearson Correlation Coefficient.
#' For Spearman Correlation, consider using RDD methods found in MLlib's Statistics.
#'
-#' @param x A SparkDataFrame
-#' @param col1 the name of the first column
-#' @param col2 the name of the second column
+#' @param colName1 the name of the first column
+#' @param colName2 the name of the second column
#' @param method Optional. A character specifying the method for calculating the correlation.
#' only "pearson" is allowed now.
#' @return The Pearson Correlation Coefficient as a Double.
@@ -102,10 +101,10 @@ setMethod("cov",
#' @note corr since 1.6.0
setMethod("corr",
signature(x = "SparkDataFrame"),
- function(x, col1, col2, method = "pearson") {
- stopifnot(class(col1) == "character" && class(col2) == "character")
+ function(x, colName1, colName2, method = "pearson") {
+ stopifnot(class(colName1) == "character" && class(colName2) == "character")
statFunctions <- callJMethod(x@sdf, "stat")
- callJMethod(statFunctions, "corr", col1, col2, method)
+ callJMethod(statFunctions, "corr", colName1, colName2, method)
})
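
The cov/corr hunks above only rename the column arguments; a minimal sketch of the resulting call shape, using R's bundled `faithful` data and assuming an active session:

```r
# Minimal sketch of the renamed arguments; assumes an active SparkR session.
df <- createDataFrame(faithful)                       # eruptions, waiting columns
cov(df, "eruptions", "waiting")                       # sample covariance (numeric)
corr(df, "eruptions", "waiting", method = "pearson")  # Pearson correlation
```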
diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R
index 240b9f669bdd7..d78c0a7a539a8 100644
--- a/R/pkg/R/utils.R
+++ b/R/pkg/R/utils.R
@@ -689,3 +689,11 @@ getSparkContext <- function() {
sc <- get(".sparkRjsc", envir = .sparkREnv)
sc
}
+
+is_master_local <- function(master) {
+ grepl("^local(\\[([0-9]+|\\*)\\])?$", master, perl = TRUE)
+}
+
+is_sparkR_shell <- function() {
+ grepl(".*shell\\.R$", Sys.getenv("R_PROFILE_USER"), perl = TRUE)
+}
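
The two helpers added to utils.R gate the auto-install path; below is a quick, standalone check of what the master regex accepts (plain R, no Spark needed). The pattern is copied from the diff; the test strings are illustrative.

```r
# Standalone check of the pattern used by is_master_local(); copied from the diff.
is_master_local <- function(master) {
  grepl("^local(\\[([0-9]+|\\*)\\])?$", master, perl = TRUE)
}
is_master_local("local")               # TRUE
is_master_local("local[4]")            # TRUE
is_master_local("local[*]")            # TRUE
is_master_local("local[2,3]")          # FALSE - the retry form is not matched
is_master_local("spark://host:7077")   # FALSE
```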
diff --git a/R/pkg/inst/tests/testthat/test_binaryFile.R b/R/pkg/inst/tests/testthat/test_binaryFile.R
index b69f017de81d1..f7a0510711da9 100644
--- a/R/pkg/inst/tests/testthat/test_binaryFile.R
+++ b/R/pkg/inst/tests/testthat/test_binaryFile.R
@@ -31,7 +31,7 @@ test_that("saveAsObjectFile()/objectFile() following textFile() works", {
rdd <- textFile(sc, fileName1, 1)
saveAsObjectFile(rdd, fileName2)
rdd <- objectFile(sc, fileName2)
- expect_equal(collect(rdd), as.list(mockFile))
+ expect_equal(collectRDD(rdd), as.list(mockFile))
unlink(fileName1)
unlink(fileName2, recursive = TRUE)
@@ -44,7 +44,7 @@ test_that("saveAsObjectFile()/objectFile() works on a parallelized list", {
rdd <- parallelize(sc, l, 1)
saveAsObjectFile(rdd, fileName)
rdd <- objectFile(sc, fileName)
- expect_equal(collect(rdd), l)
+ expect_equal(collectRDD(rdd), l)
unlink(fileName, recursive = TRUE)
})
@@ -64,7 +64,7 @@ test_that("saveAsObjectFile()/objectFile() following RDD transformations works",
saveAsObjectFile(counts, fileName2)
counts <- objectFile(sc, fileName2)
- output <- collect(counts)
+ output <- collectRDD(counts)
expected <- list(list("awesome.", 1), list("Spark", 2), list("pretty.", 1),
list("is", 2))
expect_equal(sortKeyValueList(output), sortKeyValueList(expected))
@@ -83,7 +83,7 @@ test_that("saveAsObjectFile()/objectFile() works with multiple paths", {
saveAsObjectFile(rdd2, fileName2)
rdd <- objectFile(sc, c(fileName1, fileName2))
- expect_equal(count(rdd), 2)
+ expect_equal(countRDD(rdd), 2)
unlink(fileName1, recursive = TRUE)
unlink(fileName2, recursive = TRUE)
diff --git a/R/pkg/inst/tests/testthat/test_binary_function.R b/R/pkg/inst/tests/testthat/test_binary_function.R
index 6f51d20687277..b780b9458545c 100644
--- a/R/pkg/inst/tests/testthat/test_binary_function.R
+++ b/R/pkg/inst/tests/testthat/test_binary_function.R
@@ -29,7 +29,7 @@ rdd <- parallelize(sc, nums, 2L)
mockFile <- c("Spark is pretty.", "Spark is awesome.")
test_that("union on two RDDs", {
- actual <- collect(unionRDD(rdd, rdd))
+ actual <- collectRDD(unionRDD(rdd, rdd))
expect_equal(actual, as.list(rep(nums, 2)))
fileName <- tempfile(pattern = "spark-test", fileext = ".tmp")
@@ -37,13 +37,13 @@ test_that("union on two RDDs", {
text.rdd <- textFile(sc, fileName)
union.rdd <- unionRDD(rdd, text.rdd)
- actual <- collect(union.rdd)
+ actual <- collectRDD(union.rdd)
expect_equal(actual, c(as.list(nums), mockFile))
expect_equal(getSerializedMode(union.rdd), "byte")
rdd <- map(text.rdd, function(x) {x})
union.rdd <- unionRDD(rdd, text.rdd)
- actual <- collect(union.rdd)
+ actual <- collectRDD(union.rdd)
expect_equal(actual, as.list(c(mockFile, mockFile)))
expect_equal(getSerializedMode(union.rdd), "byte")
@@ -54,14 +54,14 @@ test_that("cogroup on two RDDs", {
rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4)))
rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3)))
cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L)
- actual <- collect(cogroup.rdd)
+ actual <- collectRDD(cogroup.rdd)
expect_equal(actual,
list(list(1, list(list(1), list(2, 3))), list(2, list(list(4), list()))))
rdd1 <- parallelize(sc, list(list("a", 1), list("a", 4)))
rdd2 <- parallelize(sc, list(list("b", 2), list("a", 3)))
cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L)
- actual <- collect(cogroup.rdd)
+ actual <- collectRDD(cogroup.rdd)
expected <- list(list("b", list(list(), list(2))), list("a", list(list(1, 4), list(3))))
expect_equal(sortKeyValueList(actual),
@@ -72,7 +72,7 @@ test_that("zipPartitions() on RDDs", {
rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2
rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4
rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6
- actual <- collect(zipPartitions(rdd1, rdd2, rdd3,
+ actual <- collectRDD(zipPartitions(rdd1, rdd2, rdd3,
func = function(x, y, z) { list(list(x, y, z))} ))
expect_equal(actual,
list(list(1, c(1, 2), c(1, 2, 3)), list(2, c(3, 4), c(4, 5, 6))))
@@ -82,19 +82,19 @@ test_that("zipPartitions() on RDDs", {
writeLines(mockFile, fileName)
rdd <- textFile(sc, fileName, 1)
- actual <- collect(zipPartitions(rdd, rdd,
+ actual <- collectRDD(zipPartitions(rdd, rdd,
func = function(x, y) { list(paste(x, y, sep = "\n")) }))
expected <- list(paste(mockFile, mockFile, sep = "\n"))
expect_equal(actual, expected)
rdd1 <- parallelize(sc, 0:1, 1)
- actual <- collect(zipPartitions(rdd1, rdd,
+ actual <- collectRDD(zipPartitions(rdd1, rdd,
func = function(x, y) { list(x + nchar(y)) }))
expected <- list(0:1 + nchar(mockFile))
expect_equal(actual, expected)
rdd <- map(rdd, function(x) { x })
- actual <- collect(zipPartitions(rdd, rdd1,
+ actual <- collectRDD(zipPartitions(rdd, rdd1,
func = function(x, y) { list(y + nchar(x)) }))
expect_equal(actual, expected)
diff --git a/R/pkg/inst/tests/testthat/test_broadcast.R b/R/pkg/inst/tests/testthat/test_broadcast.R
index cf1d43277105e..064249a57aed4 100644
--- a/R/pkg/inst/tests/testthat/test_broadcast.R
+++ b/R/pkg/inst/tests/testthat/test_broadcast.R
@@ -32,7 +32,7 @@ test_that("using broadcast variable", {
useBroadcast <- function(x) {
sum(SparkR:::value(randomMatBr) * x)
}
- actual <- collect(lapply(rrdd, useBroadcast))
+ actual <- collectRDD(lapply(rrdd, useBroadcast))
expected <- list(sum(randomMat) * 1, sum(randomMat) * 2)
expect_equal(actual, expected)
})
@@ -43,7 +43,7 @@ test_that("without using broadcast variable", {
useBroadcast <- function(x) {
sum(randomMat * x)
}
- actual <- collect(lapply(rrdd, useBroadcast))
+ actual <- collectRDD(lapply(rrdd, useBroadcast))
expected <- list(sum(randomMat) * 1, sum(randomMat) * 2)
expect_equal(actual, expected)
})
diff --git a/R/pkg/inst/tests/testthat/test_context.R b/R/pkg/inst/tests/testthat/test_context.R
index 2a1bd61b11118..66640c4b08459 100644
--- a/R/pkg/inst/tests/testthat/test_context.R
+++ b/R/pkg/inst/tests/testthat/test_context.R
@@ -58,7 +58,7 @@ test_that("repeatedly starting and stopping SparkR", {
for (i in 1:4) {
sc <- suppressWarnings(sparkR.init())
rdd <- parallelize(sc, 1:20, 2L)
- expect_equal(count(rdd), 20)
+ expect_equal(countRDD(rdd), 20)
suppressWarnings(sparkR.stop())
}
})
@@ -94,8 +94,9 @@ test_that("rdd GC across sparkR.stop", {
rm(rdd2)
gc()
- count(rdd3)
- count(rdd4)
+ countRDD(rdd3)
+ countRDD(rdd4)
+ sparkR.session.stop()
})
test_that("job group functions can be called", {
diff --git a/R/pkg/inst/tests/testthat/test_includePackage.R b/R/pkg/inst/tests/testthat/test_includePackage.R
index d6a3766539c02..025eb9b9fc9d6 100644
--- a/R/pkg/inst/tests/testthat/test_includePackage.R
+++ b/R/pkg/inst/tests/testthat/test_includePackage.R
@@ -37,7 +37,7 @@ test_that("include inside function", {
}
data <- lapplyPartition(rdd, generateData)
- actual <- collect(data)
+ actual <- collectRDD(data)
}
})
@@ -53,6 +53,6 @@ test_that("use include package", {
includePackage(sc, plyr)
data <- lapplyPartition(rdd, generateData)
- actual <- collect(data)
+ actual <- collectRDD(data)
}
})
diff --git a/R/pkg/inst/tests/testthat/test_parallelize_collect.R b/R/pkg/inst/tests/testthat/test_parallelize_collect.R
index f79a8a70aafb1..1b230554f7a0e 100644
--- a/R/pkg/inst/tests/testthat/test_parallelize_collect.R
+++ b/R/pkg/inst/tests/testthat/test_parallelize_collect.R
@@ -67,22 +67,22 @@ test_that("parallelize() on simple vectors and lists returns an RDD", {
test_that("collect(), following a parallelize(), gives back the original collections", {
numVectorRDD <- parallelize(jsc, numVector, 10)
- expect_equal(collect(numVectorRDD), as.list(numVector))
+ expect_equal(collectRDD(numVectorRDD), as.list(numVector))
numListRDD <- parallelize(jsc, numList, 1)
numListRDD2 <- parallelize(jsc, numList, 4)
- expect_equal(collect(numListRDD), as.list(numList))
- expect_equal(collect(numListRDD2), as.list(numList))
+ expect_equal(collectRDD(numListRDD), as.list(numList))
+ expect_equal(collectRDD(numListRDD2), as.list(numList))
strVectorRDD <- parallelize(jsc, strVector, 2)
strVectorRDD2 <- parallelize(jsc, strVector, 3)
- expect_equal(collect(strVectorRDD), as.list(strVector))
- expect_equal(collect(strVectorRDD2), as.list(strVector))
+ expect_equal(collectRDD(strVectorRDD), as.list(strVector))
+ expect_equal(collectRDD(strVectorRDD2), as.list(strVector))
strListRDD <- parallelize(jsc, strList, 4)
strListRDD2 <- parallelize(jsc, strList, 1)
- expect_equal(collect(strListRDD), as.list(strList))
- expect_equal(collect(strListRDD2), as.list(strList))
+ expect_equal(collectRDD(strListRDD), as.list(strList))
+ expect_equal(collectRDD(strListRDD2), as.list(strList))
})
test_that("regression: collect() following a parallelize() does not drop elements", {
@@ -90,7 +90,7 @@ test_that("regression: collect() following a parallelize() does not drop element
collLen <- 10
numPart <- 6
expected <- runif(collLen)
- actual <- collect(parallelize(jsc, expected, numPart))
+ actual <- collectRDD(parallelize(jsc, expected, numPart))
expect_equal(actual, as.list(expected))
})
@@ -99,12 +99,12 @@ test_that("parallelize() and collect() work for lists of pairs (pairwise data)",
numPairsRDDD1 <- parallelize(jsc, numPairs, 1)
numPairsRDDD2 <- parallelize(jsc, numPairs, 2)
numPairsRDDD3 <- parallelize(jsc, numPairs, 3)
- expect_equal(collect(numPairsRDDD1), numPairs)
- expect_equal(collect(numPairsRDDD2), numPairs)
- expect_equal(collect(numPairsRDDD3), numPairs)
+ expect_equal(collectRDD(numPairsRDDD1), numPairs)
+ expect_equal(collectRDD(numPairsRDDD2), numPairs)
+ expect_equal(collectRDD(numPairsRDDD3), numPairs)
# can also leave out the parameter name, if the params are supplied in order
strPairsRDDD1 <- parallelize(jsc, strPairs, 1)
strPairsRDDD2 <- parallelize(jsc, strPairs, 2)
- expect_equal(collect(strPairsRDDD1), strPairs)
- expect_equal(collect(strPairsRDDD2), strPairs)
+ expect_equal(collectRDD(strPairsRDDD1), strPairs)
+ expect_equal(collectRDD(strPairsRDDD2), strPairs)
})
diff --git a/R/pkg/inst/tests/testthat/test_rdd.R b/R/pkg/inst/tests/testthat/test_rdd.R
index 429311d2924f0..d38a763bab8c6 100644
--- a/R/pkg/inst/tests/testthat/test_rdd.R
+++ b/R/pkg/inst/tests/testthat/test_rdd.R
@@ -34,14 +34,14 @@ test_that("get number of partitions in RDD", {
})
test_that("first on RDD", {
- expect_equal(first(rdd), 1)
+ expect_equal(firstRDD(rdd), 1)
newrdd <- lapply(rdd, function(x) x + 1)
- expect_equal(first(newrdd), 2)
+ expect_equal(firstRDD(newrdd), 2)
})
test_that("count and length on RDD", {
- expect_equal(count(rdd), 10)
- expect_equal(length(rdd), 10)
+ expect_equal(countRDD(rdd), 10)
+ expect_equal(lengthRDD(rdd), 10)
})
test_that("count by values and keys", {
@@ -57,40 +57,40 @@ test_that("count by values and keys", {
test_that("lapply on RDD", {
multiples <- lapply(rdd, function(x) { 2 * x })
- actual <- collect(multiples)
+ actual <- collectRDD(multiples)
expect_equal(actual, as.list(nums * 2))
})
test_that("lapplyPartition on RDD", {
sums <- lapplyPartition(rdd, function(part) { sum(unlist(part)) })
- actual <- collect(sums)
+ actual <- collectRDD(sums)
expect_equal(actual, list(15, 40))
})
test_that("mapPartitions on RDD", {
sums <- mapPartitions(rdd, function(part) { sum(unlist(part)) })
- actual <- collect(sums)
+ actual <- collectRDD(sums)
expect_equal(actual, list(15, 40))
})
test_that("flatMap() on RDDs", {
flat <- flatMap(intRdd, function(x) { list(x, x) })
- actual <- collect(flat)
+ actual <- collectRDD(flat)
expect_equal(actual, rep(intPairs, each = 2))
})
test_that("filterRDD on RDD", {
filtered.rdd <- filterRDD(rdd, function(x) { x %% 2 == 0 })
- actual <- collect(filtered.rdd)
+ actual <- collectRDD(filtered.rdd)
expect_equal(actual, list(2, 4, 6, 8, 10))
filtered.rdd <- Filter(function(x) { x[[2]] < 0 }, intRdd)
- actual <- collect(filtered.rdd)
+ actual <- collectRDD(filtered.rdd)
expect_equal(actual, list(list(1L, -1)))
# Filter out all elements.
filtered.rdd <- filterRDD(rdd, function(x) { x > 10 })
- actual <- collect(filtered.rdd)
+ actual <- collectRDD(filtered.rdd)
expect_equal(actual, list())
})
@@ -110,7 +110,7 @@ test_that("several transformations on RDD (a benchmark on PipelinedRDD)", {
part <- as.list(unlist(part) * partIndex + i)
})
rdd2 <- lapply(rdd2, function(x) x + x)
- actual <- collect(rdd2)
+ actual <- collectRDD(rdd2)
expected <- list(24, 24, 24, 24, 24,
168, 170, 172, 174, 176)
expect_equal(actual, expected)
@@ -126,20 +126,20 @@ test_that("PipelinedRDD support actions: cache(), persist(), unpersist(), checkp
part <- as.list(unlist(part) * partIndex)
})
- cache(rdd2)
+ cacheRDD(rdd2)
expect_true(rdd2@env$isCached)
rdd2 <- lapply(rdd2, function(x) x)
expect_false(rdd2@env$isCached)
- unpersist(rdd2)
+ unpersistRDD(rdd2)
expect_false(rdd2@env$isCached)
- persist(rdd2, "MEMORY_AND_DISK")
+ persistRDD(rdd2, "MEMORY_AND_DISK")
expect_true(rdd2@env$isCached)
rdd2 <- lapply(rdd2, function(x) x)
expect_false(rdd2@env$isCached)
- unpersist(rdd2)
+ unpersistRDD(rdd2)
expect_false(rdd2@env$isCached)
tempDir <- tempfile(pattern = "checkpoint")
@@ -152,7 +152,7 @@ test_that("PipelinedRDD support actions: cache(), persist(), unpersist(), checkp
expect_false(rdd2@env$isCheckpointed)
# make sure the data is collectable
- collect(rdd2)
+ collectRDD(rdd2)
unlink(tempDir)
})
@@ -169,21 +169,21 @@ test_that("reduce on RDD", {
test_that("lapply with dependency", {
fa <- 5
multiples <- lapply(rdd, function(x) { fa * x })
- actual <- collect(multiples)
+ actual <- collectRDD(multiples)
expect_equal(actual, as.list(nums * 5))
})
test_that("lapplyPartitionsWithIndex on RDDs", {
func <- function(partIndex, part) { list(partIndex, Reduce("+", part)) }
- actual <- collect(lapplyPartitionsWithIndex(rdd, func), flatten = FALSE)
+ actual <- collectRDD(lapplyPartitionsWithIndex(rdd, func), flatten = FALSE)
expect_equal(actual, list(list(0, 15), list(1, 40)))
pairsRDD <- parallelize(sc, list(list(1, 2), list(3, 4), list(4, 8)), 1L)
partitionByParity <- function(key) { if (key %% 2 == 1) 0 else 1 }
mkTup <- function(partIndex, part) { list(partIndex, part) }
- actual <- collect(lapplyPartitionsWithIndex(
- partitionBy(pairsRDD, 2L, partitionByParity),
+ actual <- collectRDD(lapplyPartitionsWithIndex(
+ partitionByRDD(pairsRDD, 2L, partitionByParity),
mkTup),
FALSE)
expect_equal(actual, list(list(0, list(list(1, 2), list(3, 4))),
@@ -191,7 +191,7 @@ test_that("lapplyPartitionsWithIndex on RDDs", {
})
test_that("sampleRDD() on RDDs", {
- expect_equal(unlist(collect(sampleRDD(rdd, FALSE, 1.0, 2014L))), nums)
+ expect_equal(unlist(collectRDD(sampleRDD(rdd, FALSE, 1.0, 2014L))), nums)
})
test_that("takeSample() on RDDs", {
@@ -238,7 +238,7 @@ test_that("takeSample() on RDDs", {
test_that("mapValues() on pairwise RDDs", {
multiples <- mapValues(intRdd, function(x) { x * 2 })
- actual <- collect(multiples)
+ actual <- collectRDD(multiples)
expected <- lapply(intPairs, function(x) {
list(x[[1]], x[[2]] * 2)
})
@@ -247,11 +247,11 @@ test_that("mapValues() on pairwise RDDs", {
test_that("flatMapValues() on pairwise RDDs", {
l <- parallelize(sc, list(list(1, c(1, 2)), list(2, c(3, 4))))
- actual <- collect(flatMapValues(l, function(x) { x }))
+ actual <- collectRDD(flatMapValues(l, function(x) { x }))
expect_equal(actual, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4)))
# Generate x to x+1 for every value
- actual <- collect(flatMapValues(intRdd, function(x) { x: (x + 1) }))
+ actual <- collectRDD(flatMapValues(intRdd, function(x) { x: (x + 1) }))
expect_equal(actual,
list(list(1L, -1), list(1L, 0), list(2L, 100), list(2L, 101),
list(2L, 1), list(2L, 2), list(1L, 200), list(1L, 201)))
@@ -273,8 +273,8 @@ test_that("reduceByKeyLocally() on PairwiseRDDs", {
test_that("distinct() on RDDs", {
nums.rep2 <- rep(1:10, 2)
rdd.rep2 <- parallelize(sc, nums.rep2, 2L)
- uniques <- distinct(rdd.rep2)
- actual <- sort(unlist(collect(uniques)))
+ uniques <- distinctRDD(rdd.rep2)
+ actual <- sort(unlist(collectRDD(uniques)))
expect_equal(actual, nums)
})
@@ -296,7 +296,7 @@ test_that("sumRDD() on RDDs", {
test_that("keyBy on RDDs", {
func <- function(x) { x * x }
keys <- keyBy(rdd, func)
- actual <- collect(keys)
+ actual <- collectRDD(keys)
expect_equal(actual, lapply(nums, function(x) { list(func(x), x) }))
})
@@ -304,12 +304,12 @@ test_that("repartition/coalesce on RDDs", {
rdd <- parallelize(sc, 1:20, 4L) # each partition contains 5 elements
# repartition
- r1 <- repartition(rdd, 2)
+ r1 <- repartitionRDD(rdd, 2)
expect_equal(getNumPartitions(r1), 2L)
count <- length(collectPartition(r1, 0L))
expect_true(count >= 8 && count <= 12)
- r2 <- repartition(rdd, 6)
+ r2 <- repartitionRDD(rdd, 6)
expect_equal(getNumPartitions(r2), 6L)
count <- length(collectPartition(r2, 0L))
expect_true(count >= 0 && count <= 4)
@@ -323,12 +323,12 @@ test_that("repartition/coalesce on RDDs", {
test_that("sortBy() on RDDs", {
sortedRdd <- sortBy(rdd, function(x) { x * x }, ascending = FALSE)
- actual <- collect(sortedRdd)
+ actual <- collectRDD(sortedRdd)
expect_equal(actual, as.list(sort(nums, decreasing = TRUE)))
rdd2 <- parallelize(sc, sort(nums, decreasing = TRUE), 2L)
sortedRdd2 <- sortBy(rdd2, function(x) { x * x })
- actual <- collect(sortedRdd2)
+ actual <- collectRDD(sortedRdd2)
expect_equal(actual, as.list(nums))
})
@@ -380,13 +380,13 @@ test_that("aggregateRDD() on RDDs", {
test_that("zipWithUniqueId() on RDDs", {
rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L)
- actual <- collect(zipWithUniqueId(rdd))
+ actual <- collectRDD(zipWithUniqueId(rdd))
expected <- list(list("a", 0), list("b", 3), list("c", 1),
list("d", 4), list("e", 2))
expect_equal(actual, expected)
rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 1L)
- actual <- collect(zipWithUniqueId(rdd))
+ actual <- collectRDD(zipWithUniqueId(rdd))
expected <- list(list("a", 0), list("b", 1), list("c", 2),
list("d", 3), list("e", 4))
expect_equal(actual, expected)
@@ -394,13 +394,13 @@ test_that("zipWithUniqueId() on RDDs", {
test_that("zipWithIndex() on RDDs", {
rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L)
- actual <- collect(zipWithIndex(rdd))
+ actual <- collectRDD(zipWithIndex(rdd))
expected <- list(list("a", 0), list("b", 1), list("c", 2),
list("d", 3), list("e", 4))
expect_equal(actual, expected)
rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 1L)
- actual <- collect(zipWithIndex(rdd))
+ actual <- collectRDD(zipWithIndex(rdd))
expected <- list(list("a", 0), list("b", 1), list("c", 2),
list("d", 3), list("e", 4))
expect_equal(actual, expected)
@@ -408,35 +408,35 @@ test_that("zipWithIndex() on RDDs", {
test_that("glom() on RDD", {
rdd <- parallelize(sc, as.list(1:4), 2L)
- actual <- collect(glom(rdd))
+ actual <- collectRDD(glom(rdd))
expect_equal(actual, list(list(1, 2), list(3, 4)))
})
test_that("keys() on RDDs", {
keys <- keys(intRdd)
- actual <- collect(keys)
+ actual <- collectRDD(keys)
expect_equal(actual, lapply(intPairs, function(x) { x[[1]] }))
})
test_that("values() on RDDs", {
values <- values(intRdd)
- actual <- collect(values)
+ actual <- collectRDD(values)
expect_equal(actual, lapply(intPairs, function(x) { x[[2]] }))
})
test_that("pipeRDD() on RDDs", {
- actual <- collect(pipeRDD(rdd, "more"))
+ actual <- collectRDD(pipeRDD(rdd, "more"))
expected <- as.list(as.character(1:10))
expect_equal(actual, expected)
trailed.rdd <- parallelize(sc, c("1", "", "2\n", "3\n\r\n"))
- actual <- collect(pipeRDD(trailed.rdd, "sort"))
+ actual <- collectRDD(pipeRDD(trailed.rdd, "sort"))
expected <- list("", "1", "2", "3")
expect_equal(actual, expected)
rev.nums <- 9:0
rev.rdd <- parallelize(sc, rev.nums, 2L)
- actual <- collect(pipeRDD(rev.rdd, "sort"))
+ actual <- collectRDD(pipeRDD(rev.rdd, "sort"))
expected <- as.list(as.character(c(5:9, 0:4)))
expect_equal(actual, expected)
})
@@ -444,7 +444,7 @@ test_that("pipeRDD() on RDDs", {
test_that("zipRDD() on RDDs", {
rdd1 <- parallelize(sc, 0:4, 2)
rdd2 <- parallelize(sc, 1000:1004, 2)
- actual <- collect(zipRDD(rdd1, rdd2))
+ actual <- collectRDD(zipRDD(rdd1, rdd2))
expect_equal(actual,
list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004)))
@@ -453,17 +453,17 @@ test_that("zipRDD() on RDDs", {
writeLines(mockFile, fileName)
rdd <- textFile(sc, fileName, 1)
- actual <- collect(zipRDD(rdd, rdd))
+ actual <- collectRDD(zipRDD(rdd, rdd))
expected <- lapply(mockFile, function(x) { list(x, x) })
expect_equal(actual, expected)
rdd1 <- parallelize(sc, 0:1, 1)
- actual <- collect(zipRDD(rdd1, rdd))
+ actual <- collectRDD(zipRDD(rdd1, rdd))
expected <- lapply(0:1, function(x) { list(x, mockFile[x + 1]) })
expect_equal(actual, expected)
rdd1 <- map(rdd, function(x) { x })
- actual <- collect(zipRDD(rdd, rdd1))
+ actual <- collectRDD(zipRDD(rdd, rdd1))
expected <- lapply(mockFile, function(x) { list(x, x) })
expect_equal(actual, expected)
@@ -472,7 +472,7 @@ test_that("zipRDD() on RDDs", {
test_that("cartesian() on RDDs", {
rdd <- parallelize(sc, 1:3)
- actual <- collect(cartesian(rdd, rdd))
+ actual <- collectRDD(cartesian(rdd, rdd))
expect_equal(sortKeyValueList(actual),
list(
list(1, 1), list(1, 2), list(1, 3),
@@ -481,7 +481,7 @@ test_that("cartesian() on RDDs", {
# test case where one RDD is empty
emptyRdd <- parallelize(sc, list())
- actual <- collect(cartesian(rdd, emptyRdd))
+ actual <- collectRDD(cartesian(rdd, emptyRdd))
expect_equal(actual, list())
mockFile <- c("Spark is pretty.", "Spark is awesome.")
@@ -489,7 +489,7 @@ test_that("cartesian() on RDDs", {
writeLines(mockFile, fileName)
rdd <- textFile(sc, fileName)
- actual <- collect(cartesian(rdd, rdd))
+ actual <- collectRDD(cartesian(rdd, rdd))
expected <- list(
list("Spark is awesome.", "Spark is pretty."),
list("Spark is awesome.", "Spark is awesome."),
@@ -498,7 +498,7 @@ test_that("cartesian() on RDDs", {
expect_equal(sortKeyValueList(actual), expected)
rdd1 <- parallelize(sc, 0:1)
- actual <- collect(cartesian(rdd1, rdd))
+ actual <- collectRDD(cartesian(rdd1, rdd))
expect_equal(sortKeyValueList(actual),
list(
list(0, "Spark is pretty."),
@@ -507,7 +507,7 @@ test_that("cartesian() on RDDs", {
list(1, "Spark is awesome.")))
rdd1 <- map(rdd, function(x) { x })
- actual <- collect(cartesian(rdd, rdd1))
+ actual <- collectRDD(cartesian(rdd, rdd1))
expect_equal(sortKeyValueList(actual), expected)
unlink(fileName)
@@ -518,24 +518,24 @@ test_that("subtract() on RDDs", {
rdd1 <- parallelize(sc, l)
# subtract by itself
- actual <- collect(subtract(rdd1, rdd1))
+ actual <- collectRDD(subtract(rdd1, rdd1))
expect_equal(actual, list())
# subtract by an empty RDD
rdd2 <- parallelize(sc, list())
- actual <- collect(subtract(rdd1, rdd2))
+ actual <- collectRDD(subtract(rdd1, rdd2))
expect_equal(as.list(sort(as.vector(actual, mode = "integer"))),
l)
rdd2 <- parallelize(sc, list(2, 4))
- actual <- collect(subtract(rdd1, rdd2))
+ actual <- collectRDD(subtract(rdd1, rdd2))
expect_equal(as.list(sort(as.vector(actual, mode = "integer"))),
list(1, 1, 3))
l <- list("a", "a", "b", "b", "c", "d")
rdd1 <- parallelize(sc, l)
rdd2 <- parallelize(sc, list("b", "d"))
- actual <- collect(subtract(rdd1, rdd2))
+ actual <- collectRDD(subtract(rdd1, rdd2))
expect_equal(as.list(sort(as.vector(actual, mode = "character"))),
list("a", "a", "c"))
})
@@ -546,17 +546,17 @@ test_that("subtractByKey() on pairwise RDDs", {
rdd1 <- parallelize(sc, l)
# subtractByKey by itself
- actual <- collect(subtractByKey(rdd1, rdd1))
+ actual <- collectRDD(subtractByKey(rdd1, rdd1))
expect_equal(actual, list())
# subtractByKey by an empty RDD
rdd2 <- parallelize(sc, list())
- actual <- collect(subtractByKey(rdd1, rdd2))
+ actual <- collectRDD(subtractByKey(rdd1, rdd2))
expect_equal(sortKeyValueList(actual),
sortKeyValueList(l))
rdd2 <- parallelize(sc, list(list("a", 3), list("c", 1)))
- actual <- collect(subtractByKey(rdd1, rdd2))
+ actual <- collectRDD(subtractByKey(rdd1, rdd2))
expect_equal(actual,
list(list("b", 4), list("b", 5)))
@@ -564,76 +564,76 @@ test_that("subtractByKey() on pairwise RDDs", {
list(2, 5), list(1, 2))
rdd1 <- parallelize(sc, l)
rdd2 <- parallelize(sc, list(list(1, 3), list(3, 1)))
- actual <- collect(subtractByKey(rdd1, rdd2))
+ actual <- collectRDD(subtractByKey(rdd1, rdd2))
expect_equal(actual,
list(list(2, 4), list(2, 5)))
})
test_that("intersection() on RDDs", {
# intersection with self
- actual <- collect(intersection(rdd, rdd))
+ actual <- collectRDD(intersection(rdd, rdd))
expect_equal(sort(as.integer(actual)), nums)
# intersection with an empty RDD
emptyRdd <- parallelize(sc, list())
- actual <- collect(intersection(rdd, emptyRdd))
+ actual <- collectRDD(intersection(rdd, emptyRdd))
expect_equal(actual, list())
rdd1 <- parallelize(sc, list(1, 10, 2, 3, 4, 5))
rdd2 <- parallelize(sc, list(1, 6, 2, 3, 7, 8))
- actual <- collect(intersection(rdd1, rdd2))
+ actual <- collectRDD(intersection(rdd1, rdd2))
expect_equal(sort(as.integer(actual)), 1:3)
})
test_that("join() on pairwise RDDs", {
rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4)))
rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3)))
- actual <- collect(join(rdd1, rdd2, 2L))
+ actual <- collectRDD(joinRDD(rdd1, rdd2, 2L))
expect_equal(sortKeyValueList(actual),
sortKeyValueList(list(list(1, list(1, 2)), list(1, list(1, 3)))))
rdd1 <- parallelize(sc, list(list("a", 1), list("b", 4)))
rdd2 <- parallelize(sc, list(list("a", 2), list("a", 3)))
- actual <- collect(join(rdd1, rdd2, 2L))
+ actual <- collectRDD(joinRDD(rdd1, rdd2, 2L))
expect_equal(sortKeyValueList(actual),
sortKeyValueList(list(list("a", list(1, 2)), list("a", list(1, 3)))))
rdd1 <- parallelize(sc, list(list(1, 1), list(2, 2)))
rdd2 <- parallelize(sc, list(list(3, 3), list(4, 4)))
- actual <- collect(join(rdd1, rdd2, 2L))
+ actual <- collectRDD(joinRDD(rdd1, rdd2, 2L))
expect_equal(actual, list())
rdd1 <- parallelize(sc, list(list("a", 1), list("b", 2)))
rdd2 <- parallelize(sc, list(list("c", 3), list("d", 4)))
- actual <- collect(join(rdd1, rdd2, 2L))
+ actual <- collectRDD(joinRDD(rdd1, rdd2, 2L))
expect_equal(actual, list())
})
test_that("leftOuterJoin() on pairwise RDDs", {
rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4)))
rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3)))
- actual <- collect(leftOuterJoin(rdd1, rdd2, 2L))
+ actual <- collectRDD(leftOuterJoin(rdd1, rdd2, 2L))
expected <- list(list(1, list(1, 2)), list(1, list(1, 3)), list(2, list(4, NULL)))
expect_equal(sortKeyValueList(actual),
sortKeyValueList(expected))
rdd1 <- parallelize(sc, list(list("a", 1), list("b", 4)))
rdd2 <- parallelize(sc, list(list("a", 2), list("a", 3)))
- actual <- collect(leftOuterJoin(rdd1, rdd2, 2L))
+ actual <- collectRDD(leftOuterJoin(rdd1, rdd2, 2L))
expected <- list(list("b", list(4, NULL)), list("a", list(1, 2)), list("a", list(1, 3)))
expect_equal(sortKeyValueList(actual),
sortKeyValueList(expected))
rdd1 <- parallelize(sc, list(list(1, 1), list(2, 2)))
rdd2 <- parallelize(sc, list(list(3, 3), list(4, 4)))
- actual <- collect(leftOuterJoin(rdd1, rdd2, 2L))
+ actual <- collectRDD(leftOuterJoin(rdd1, rdd2, 2L))
expected <- list(list(1, list(1, NULL)), list(2, list(2, NULL)))
expect_equal(sortKeyValueList(actual),
sortKeyValueList(expected))
rdd1 <- parallelize(sc, list(list("a", 1), list("b", 2)))
rdd2 <- parallelize(sc, list(list("c", 3), list("d", 4)))
- actual <- collect(leftOuterJoin(rdd1, rdd2, 2L))
+ actual <- collectRDD(leftOuterJoin(rdd1, rdd2, 2L))
expected <- list(list("b", list(2, NULL)), list("a", list(1, NULL)))
expect_equal(sortKeyValueList(actual),
sortKeyValueList(expected))
@@ -642,26 +642,26 @@ test_that("leftOuterJoin() on pairwise RDDs", {
test_that("rightOuterJoin() on pairwise RDDs", {
rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3)))
rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4)))
- actual <- collect(rightOuterJoin(rdd1, rdd2, 2L))
+ actual <- collectRDD(rightOuterJoin(rdd1, rdd2, 2L))
expected <- list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4)))
expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
rdd1 <- parallelize(sc, list(list("a", 2), list("a", 3)))
rdd2 <- parallelize(sc, list(list("a", 1), list("b", 4)))
- actual <- collect(rightOuterJoin(rdd1, rdd2, 2L))
+ actual <- collectRDD(rightOuterJoin(rdd1, rdd2, 2L))
expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)), list("a", list(3, 1)))
expect_equal(sortKeyValueList(actual),
sortKeyValueList(expected))
rdd1 <- parallelize(sc, list(list(1, 1), list(2, 2)))
rdd2 <- parallelize(sc, list(list(3, 3), list(4, 4)))
- actual <- collect(rightOuterJoin(rdd1, rdd2, 2L))
+ actual <- collectRDD(rightOuterJoin(rdd1, rdd2, 2L))
expect_equal(sortKeyValueList(actual),
sortKeyValueList(list(list(3, list(NULL, 3)), list(4, list(NULL, 4)))))
rdd1 <- parallelize(sc, list(list("a", 1), list("b", 2)))
rdd2 <- parallelize(sc, list(list("c", 3), list("d", 4)))
- actual <- collect(rightOuterJoin(rdd1, rdd2, 2L))
+ actual <- collectRDD(rightOuterJoin(rdd1, rdd2, 2L))
expect_equal(sortKeyValueList(actual),
sortKeyValueList(list(list("d", list(NULL, 4)), list("c", list(NULL, 3)))))
})
@@ -669,14 +669,14 @@ test_that("rightOuterJoin() on pairwise RDDs", {
test_that("fullOuterJoin() on pairwise RDDs", {
rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3), list(3, 3)))
rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4)))
- actual <- collect(fullOuterJoin(rdd1, rdd2, 2L))
+ actual <- collectRDD(fullOuterJoin(rdd1, rdd2, 2L))
expected <- list(list(1, list(2, 1)), list(1, list(3, 1)),
list(2, list(NULL, 4)), list(3, list(3, NULL)))
expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
rdd1 <- parallelize(sc, list(list("a", 2), list("a", 3), list("c", 1)))
rdd2 <- parallelize(sc, list(list("a", 1), list("b", 4)))
- actual <- collect(fullOuterJoin(rdd1, rdd2, 2L))
+ actual <- collectRDD(fullOuterJoin(rdd1, rdd2, 2L))
expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)),
list("a", list(3, 1)), list("c", list(1, NULL)))
expect_equal(sortKeyValueList(actual),
@@ -684,14 +684,14 @@ test_that("fullOuterJoin() on pairwise RDDs", {
rdd1 <- parallelize(sc, list(list(1, 1), list(2, 2)))
rdd2 <- parallelize(sc, list(list(3, 3), list(4, 4)))
- actual <- collect(fullOuterJoin(rdd1, rdd2, 2L))
+ actual <- collectRDD(fullOuterJoin(rdd1, rdd2, 2L))
expect_equal(sortKeyValueList(actual),
sortKeyValueList(list(list(1, list(1, NULL)), list(2, list(2, NULL)),
list(3, list(NULL, 3)), list(4, list(NULL, 4)))))
rdd1 <- parallelize(sc, list(list("a", 1), list("b", 2)))
rdd2 <- parallelize(sc, list(list("c", 3), list("d", 4)))
- actual <- collect(fullOuterJoin(rdd1, rdd2, 2L))
+ actual <- collectRDD(fullOuterJoin(rdd1, rdd2, 2L))
expect_equal(sortKeyValueList(actual),
sortKeyValueList(list(list("a", list(1, NULL)), list("b", list(2, NULL)),
list("d", list(NULL, 4)), list("c", list(NULL, 3)))))
@@ -700,21 +700,21 @@ test_that("fullOuterJoin() on pairwise RDDs", {
test_that("sortByKey() on pairwise RDDs", {
numPairsRdd <- map(rdd, function(x) { list (x, x) })
sortedRdd <- sortByKey(numPairsRdd, ascending = FALSE)
- actual <- collect(sortedRdd)
+ actual <- collectRDD(sortedRdd)
numPairs <- lapply(nums, function(x) { list (x, x) })
expect_equal(actual, sortKeyValueList(numPairs, decreasing = TRUE))
rdd2 <- parallelize(sc, sort(nums, decreasing = TRUE), 2L)
numPairsRdd2 <- map(rdd2, function(x) { list (x, x) })
sortedRdd2 <- sortByKey(numPairsRdd2)
- actual <- collect(sortedRdd2)
+ actual <- collectRDD(sortedRdd2)
expect_equal(actual, numPairs)
# sort by string keys
l <- list(list("a", 1), list("b", 2), list("1", 3), list("d", 4), list("2", 5))
rdd3 <- parallelize(sc, l, 2L)
sortedRdd3 <- sortByKey(rdd3)
- actual <- collect(sortedRdd3)
+ actual <- collectRDD(sortedRdd3)
expect_equal(actual, list(list("1", 3), list("2", 5), list("a", 1), list("b", 2), list("d", 4)))
# test on the boundary cases
@@ -722,27 +722,27 @@ test_that("sortByKey() on pairwise RDDs", {
# boundary case 1: the RDD to be sorted has only 1 partition
rdd4 <- parallelize(sc, l, 1L)
sortedRdd4 <- sortByKey(rdd4)
- actual <- collect(sortedRdd4)
+ actual <- collectRDD(sortedRdd4)
expect_equal(actual, list(list("1", 3), list("2", 5), list("a", 1), list("b", 2), list("d", 4)))
# boundary case 2: the sorted RDD has only 1 partition
rdd5 <- parallelize(sc, l, 2L)
sortedRdd5 <- sortByKey(rdd5, numPartitions = 1L)
- actual <- collect(sortedRdd5)
+ actual <- collectRDD(sortedRdd5)
expect_equal(actual, list(list("1", 3), list("2", 5), list("a", 1), list("b", 2), list("d", 4)))
# boundary case 3: the RDD to be sorted has only 1 element
l2 <- list(list("a", 1))
rdd6 <- parallelize(sc, l2, 2L)
sortedRdd6 <- sortByKey(rdd6)
- actual <- collect(sortedRdd6)
+ actual <- collectRDD(sortedRdd6)
expect_equal(actual, l2)
# boundary case 4: the RDD to be sorted has 0 element
l3 <- list()
rdd7 <- parallelize(sc, l3, 2L)
sortedRdd7 <- sortByKey(rdd7)
- actual <- collect(sortedRdd7)
+ actual <- collectRDD(sortedRdd7)
expect_equal(actual, l3)
})
@@ -766,7 +766,7 @@ test_that("collectAsMap() on a pairwise RDD", {
test_that("show()", {
rdd <- parallelize(sc, list(1:10))
- expect_output(show(rdd), "ParallelCollectionRDD\\[\\d+\\] at parallelize at RRDD\\.scala:\\d+")
+ expect_output(showRDD(rdd), "ParallelCollectionRDD\\[\\d+\\] at parallelize at RRDD\\.scala:\\d+")
})
test_that("sampleByKey() on pairwise RDDs", {
diff --git a/R/pkg/inst/tests/testthat/test_shuffle.R b/R/pkg/inst/tests/testthat/test_shuffle.R
index 7d4f342016441..07f3b02df6649 100644
--- a/R/pkg/inst/tests/testthat/test_shuffle.R
+++ b/R/pkg/inst/tests/testthat/test_shuffle.R
@@ -39,7 +39,7 @@ strListRDD <- parallelize(sc, strList, 4)
test_that("groupByKey for integers", {
grouped <- groupByKey(intRdd, 2L)
- actual <- collect(grouped)
+ actual <- collectRDD(grouped)
expected <- list(list(2L, list(100, 1)), list(1L, list(-1, 200)))
expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
@@ -48,7 +48,7 @@ test_that("groupByKey for integers", {
test_that("groupByKey for doubles", {
grouped <- groupByKey(doubleRdd, 2L)
- actual <- collect(grouped)
+ actual <- collectRDD(grouped)
expected <- list(list(1.5, list(-1, 200)), list(2.5, list(100, 1)))
expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
@@ -57,7 +57,7 @@ test_that("groupByKey for doubles", {
test_that("reduceByKey for ints", {
reduced <- reduceByKey(intRdd, "+", 2L)
- actual <- collect(reduced)
+ actual <- collectRDD(reduced)
expected <- list(list(2L, 101), list(1L, 199))
expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
@@ -65,7 +65,7 @@ test_that("reduceByKey for ints", {
test_that("reduceByKey for doubles", {
reduced <- reduceByKey(doubleRdd, "+", 2L)
- actual <- collect(reduced)
+ actual <- collectRDD(reduced)
expected <- list(list(1.5, 199), list(2.5, 101))
expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
@@ -74,7 +74,7 @@ test_that("reduceByKey for doubles", {
test_that("combineByKey for ints", {
reduced <- combineByKey(intRdd, function(x) { x }, "+", "+", 2L)
- actual <- collect(reduced)
+ actual <- collectRDD(reduced)
expected <- list(list(2L, 101), list(1L, 199))
expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
@@ -82,7 +82,7 @@ test_that("combineByKey for ints", {
test_that("combineByKey for doubles", {
reduced <- combineByKey(doubleRdd, function(x) { x }, "+", "+", 2L)
- actual <- collect(reduced)
+ actual <- collectRDD(reduced)
expected <- list(list(1.5, 199), list(2.5, 101))
expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
@@ -94,7 +94,7 @@ test_that("combineByKey for characters", {
list("other", 3L), list("max", 4L)), 2L)
reduced <- combineByKey(stringKeyRDD,
function(x) { x }, "+", "+", 2L)
- actual <- collect(reduced)
+ actual <- collectRDD(reduced)
expected <- list(list("max", 5L), list("min", 2L), list("other", 3L))
expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
@@ -109,7 +109,7 @@ test_that("aggregateByKey", {
combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) }
aggregatedRDD <- aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L)
- actual <- collect(aggregatedRDD)
+ actual <- collectRDD(aggregatedRDD)
expected <- list(list(1, list(3, 2)), list(2, list(7, 2)))
expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
@@ -122,7 +122,7 @@ test_that("aggregateByKey", {
combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) }
aggregatedRDD <- aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L)
- actual <- collect(aggregatedRDD)
+ actual <- collectRDD(aggregatedRDD)
expected <- list(list("a", list(3, 2)), list("b", list(7, 2)))
expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
@@ -132,7 +132,7 @@ test_that("foldByKey", {
# test foldByKey for int keys
folded <- foldByKey(intRdd, 0, "+", 2L)
- actual <- collect(folded)
+ actual <- collectRDD(folded)
expected <- list(list(2L, 101), list(1L, 199))
expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
@@ -140,7 +140,7 @@ test_that("foldByKey", {
# test foldByKey for double keys
folded <- foldByKey(doubleRdd, 0, "+", 2L)
- actual <- collect(folded)
+ actual <- collectRDD(folded)
expected <- list(list(1.5, 199), list(2.5, 101))
expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
@@ -151,7 +151,7 @@ test_that("foldByKey", {
stringKeyRDD <- parallelize(sc, stringKeyPairs)
folded <- foldByKey(stringKeyRDD, 0, "+", 2L)
- actual <- collect(folded)
+ actual <- collectRDD(folded)
expected <- list(list("b", 101), list("a", 199))
expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
@@ -159,14 +159,14 @@ test_that("foldByKey", {
# test foldByKey for empty pair RDD
rdd <- parallelize(sc, list())
folded <- foldByKey(rdd, 0, "+", 2L)
- actual <- collect(folded)
+ actual <- collectRDD(folded)
expected <- list()
expect_equal(actual, expected)
# test foldByKey for RDD with only 1 pair
rdd <- parallelize(sc, list(list(1, 1)))
folded <- foldByKey(rdd, 0, "+", 2L)
- actual <- collect(folded)
+ actual <- collectRDD(folded)
expected <- list(list(1, 1))
expect_equal(actual, expected)
})
@@ -175,7 +175,7 @@ test_that("partitionBy() partitions data correctly", {
# Partition by magnitude
partitionByMagnitude <- function(key) { if (key >= 3) 1 else 0 }
- resultRDD <- partitionBy(numPairsRdd, 2L, partitionByMagnitude)
+ resultRDD <- partitionByRDD(numPairsRdd, 2L, partitionByMagnitude)
expected_first <- list(list(1, 100), list(2, 200)) # key less than 3
expected_second <- list(list(4, -1), list(3, 1), list(3, 0)) # key greater than or equal 3
@@ -191,7 +191,7 @@ test_that("partitionBy works with dependencies", {
partitionByParity <- function(key) { if (key %% 2 == kOne) 7 else 4 }
# Partition by parity
- resultRDD <- partitionBy(numPairsRdd, numPartitions = 2L, partitionByParity)
+ resultRDD <- partitionByRDD(numPairsRdd, numPartitions = 2L, partitionByParity)
# keys even; 100 %% 2 == 0
expected_first <- list(list(2, 200), list(4, -1))
@@ -208,7 +208,7 @@ test_that("test partitionBy with string keys", {
words <- flatMap(strListRDD, function(line) { strsplit(line, " ")[[1]] })
wordCount <- lapply(words, function(word) { list(word, 1L) })
- resultRDD <- partitionBy(wordCount, 2L)
+ resultRDD <- partitionByRDD(wordCount, 2L)
expected_first <- list(list("Dexter", 1), list("Dexter", 1))
expected_second <- list(list("and", 1), list("and", 1))
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 7e59fdf4620e1..0aea89ddcb076 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -490,7 +490,7 @@ test_that("read/write json files", {
test_that("jsonRDD() on a RDD with json string", {
sqlContext <- suppressWarnings(sparkRSQL.init(sc))
rdd <- parallelize(sc, mockLines)
- expect_equal(count(rdd), 3)
+ expect_equal(countRDD(rdd), 3)
df <- suppressWarnings(jsonRDD(sqlContext, rdd))
expect_is(df, "SparkDataFrame")
expect_equal(count(df), 3)
@@ -582,7 +582,7 @@ test_that("toRDD() returns an RRDD", {
df <- read.json(jsonPath)
testRDD <- toRDD(df)
expect_is(testRDD, "RDD")
- expect_equal(count(testRDD), 3)
+ expect_equal(countRDD(testRDD), 3)
})
test_that("union on two RDDs created from DataFrames returns an RRDD", {
@@ -592,7 +592,7 @@ test_that("union on two RDDs created from DataFrames returns an RRDD", {
unioned <- unionRDD(RDD1, RDD2)
expect_is(unioned, "RDD")
expect_equal(getSerializedMode(unioned), "byte")
- expect_equal(collect(unioned)[[2]]$name, "Andy")
+ expect_equal(collectRDD(unioned)[[2]]$name, "Andy")
})
test_that("union on mixed serialization types correctly returns a byte RRDD", {
@@ -614,14 +614,14 @@ test_that("union on mixed serialization types correctly returns a byte RRDD", {
unionByte <- unionRDD(rdd, dfRDD)
expect_is(unionByte, "RDD")
expect_equal(getSerializedMode(unionByte), "byte")
- expect_equal(collect(unionByte)[[1]], 1)
- expect_equal(collect(unionByte)[[12]]$name, "Andy")
+ expect_equal(collectRDD(unionByte)[[1]], 1)
+ expect_equal(collectRDD(unionByte)[[12]]$name, "Andy")
unionString <- unionRDD(textRDD, dfRDD)
expect_is(unionString, "RDD")
expect_equal(getSerializedMode(unionString), "byte")
- expect_equal(collect(unionString)[[1]], "Michael")
- expect_equal(collect(unionString)[[5]]$name, "Andy")
+ expect_equal(collectRDD(unionString)[[1]], "Michael")
+ expect_equal(collectRDD(unionString)[[5]]$name, "Andy")
})
test_that("objectFile() works with row serialization", {
@@ -633,7 +633,7 @@ test_that("objectFile() works with row serialization", {
expect_is(objectIn, "RDD")
expect_equal(getSerializedMode(objectIn), "byte")
- expect_equal(collect(objectIn)[[2]]$age, 30)
+ expect_equal(collectRDD(objectIn)[[2]]$age, 30)
})
test_that("lapply() on a DataFrame returns an RDD with the correct columns", {
@@ -643,7 +643,7 @@ test_that("lapply() on a DataFrame returns an RDD with the correct columns", {
row
})
expect_is(testRDD, "RDD")
- collected <- collect(testRDD)
+ collected <- collectRDD(testRDD)
expect_equal(collected[[1]]$name, "Michael")
expect_equal(collected[[2]]$newCol, 35)
})
@@ -715,10 +715,10 @@ test_that("multiple pipeline transformations result in an RDD with the correct v
row
})
expect_is(second, "RDD")
- expect_equal(count(second), 3)
- expect_equal(collect(second)[[2]]$age, 35)
- expect_true(collect(second)[[2]]$testCol)
- expect_false(collect(second)[[3]]$testCol)
+ expect_equal(countRDD(second), 3)
+ expect_equal(collectRDD(second)[[2]]$age, 35)
+ expect_true(collectRDD(second)[[2]]$testCol)
+ expect_false(collectRDD(second)[[3]]$testCol)
})
test_that("cache(), persist(), and unpersist() on a DataFrame", {
@@ -1608,7 +1608,7 @@ test_that("toJSON() returns an RDD of the correct values", {
testRDD <- toJSON(df)
expect_is(testRDD, "RDD")
expect_equal(getSerializedMode(testRDD), "string")
- expect_equal(collect(testRDD)[[1]], mockLines[1])
+ expect_equal(collectRDD(testRDD)[[1]], mockLines[1])
})
test_that("showDF()", {
diff --git a/R/pkg/inst/tests/testthat/test_take.R b/R/pkg/inst/tests/testthat/test_take.R
index daf5e41abe13f..dcf479363b9a8 100644
--- a/R/pkg/inst/tests/testthat/test_take.R
+++ b/R/pkg/inst/tests/testthat/test_take.R
@@ -36,32 +36,32 @@ sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext",
test_that("take() gives back the original elements in correct count and order", {
numVectorRDD <- parallelize(sc, numVector, 10)
# case: number of elements to take is less than the size of the first partition
- expect_equal(take(numVectorRDD, 1), as.list(head(numVector, n = 1)))
+ expect_equal(takeRDD(numVectorRDD, 1), as.list(head(numVector, n = 1)))
# case: number of elements to take is the same as the size of the first partition
- expect_equal(take(numVectorRDD, 11), as.list(head(numVector, n = 11)))
+ expect_equal(takeRDD(numVectorRDD, 11), as.list(head(numVector, n = 11)))
# case: number of elements to take is greater than all elements
- expect_equal(take(numVectorRDD, length(numVector)), as.list(numVector))
- expect_equal(take(numVectorRDD, length(numVector) + 1), as.list(numVector))
+ expect_equal(takeRDD(numVectorRDD, length(numVector)), as.list(numVector))
+ expect_equal(takeRDD(numVectorRDD, length(numVector) + 1), as.list(numVector))
numListRDD <- parallelize(sc, numList, 1)
numListRDD2 <- parallelize(sc, numList, 4)
- expect_equal(take(numListRDD, 3), take(numListRDD2, 3))
- expect_equal(take(numListRDD, 5), take(numListRDD2, 5))
- expect_equal(take(numListRDD, 1), as.list(head(numList, n = 1)))
- expect_equal(take(numListRDD2, 999), numList)
+ expect_equal(takeRDD(numListRDD, 3), takeRDD(numListRDD2, 3))
+ expect_equal(takeRDD(numListRDD, 5), takeRDD(numListRDD2, 5))
+ expect_equal(takeRDD(numListRDD, 1), as.list(head(numList, n = 1)))
+ expect_equal(takeRDD(numListRDD2, 999), numList)
strVectorRDD <- parallelize(sc, strVector, 2)
strVectorRDD2 <- parallelize(sc, strVector, 3)
- expect_equal(take(strVectorRDD, 4), as.list(strVector))
- expect_equal(take(strVectorRDD2, 2), as.list(head(strVector, n = 2)))
+ expect_equal(takeRDD(strVectorRDD, 4), as.list(strVector))
+ expect_equal(takeRDD(strVectorRDD2, 2), as.list(head(strVector, n = 2)))
strListRDD <- parallelize(sc, strList, 4)
strListRDD2 <- parallelize(sc, strList, 1)
- expect_equal(take(strListRDD, 3), as.list(head(strList, n = 3)))
- expect_equal(take(strListRDD2, 1), as.list(head(strList, n = 1)))
+ expect_equal(takeRDD(strListRDD, 3), as.list(head(strList, n = 3)))
+ expect_equal(takeRDD(strListRDD2, 1), as.list(head(strList, n = 1)))
- expect_equal(length(take(strListRDD, 0)), 0)
- expect_equal(length(take(strVectorRDD, 0)), 0)
- expect_equal(length(take(numListRDD, 0)), 0)
- expect_equal(length(take(numVectorRDD, 0)), 0)
+ expect_equal(length(takeRDD(strListRDD, 0)), 0)
+ expect_equal(length(takeRDD(strVectorRDD, 0)), 0)
+ expect_equal(length(takeRDD(numListRDD, 0)), 0)
+ expect_equal(length(takeRDD(numVectorRDD, 0)), 0)
})
diff --git a/R/pkg/inst/tests/testthat/test_textFile.R b/R/pkg/inst/tests/testthat/test_textFile.R
index 7b2cc74753fe2..ba434a5d4127b 100644
--- a/R/pkg/inst/tests/testthat/test_textFile.R
+++ b/R/pkg/inst/tests/testthat/test_textFile.R
@@ -29,8 +29,8 @@ test_that("textFile() on a local file returns an RDD", {
rdd <- textFile(sc, fileName)
expect_is(rdd, "RDD")
- expect_true(count(rdd) > 0)
- expect_equal(count(rdd), 2)
+ expect_true(countRDD(rdd) > 0)
+ expect_equal(countRDD(rdd), 2)
unlink(fileName)
})
@@ -40,7 +40,7 @@ test_that("textFile() followed by a collect() returns the same content", {
writeLines(mockFile, fileName)
rdd <- textFile(sc, fileName)
- expect_equal(collect(rdd), as.list(mockFile))
+ expect_equal(collectRDD(rdd), as.list(mockFile))
unlink(fileName)
})
@@ -55,7 +55,7 @@ test_that("textFile() word count works as expected", {
wordCount <- lapply(words, function(word) { list(word, 1L) })
counts <- reduceByKey(wordCount, "+", 2L)
- output <- collect(counts)
+ output <- collectRDD(counts)
expected <- list(list("pretty.", 1), list("is", 2), list("awesome.", 1),
list("Spark", 2))
expect_equal(sortKeyValueList(output), sortKeyValueList(expected))
@@ -72,7 +72,7 @@ test_that("several transformations on RDD created by textFile()", {
# PipelinedRDD initially created from RDD
rdd <- lapply(rdd, function(x) paste(x, x))
}
- collect(rdd)
+ collectRDD(rdd)
unlink(fileName)
})
@@ -85,7 +85,7 @@ test_that("textFile() followed by a saveAsTextFile() returns the same content",
rdd <- textFile(sc, fileName1, 1L)
saveAsTextFile(rdd, fileName2)
rdd <- textFile(sc, fileName2)
- expect_equal(collect(rdd), as.list(mockFile))
+ expect_equal(collectRDD(rdd), as.list(mockFile))
unlink(fileName1)
unlink(fileName2)
@@ -97,7 +97,7 @@ test_that("saveAsTextFile() on a parallelized list works as expected", {
rdd <- parallelize(sc, l, 1L)
saveAsTextFile(rdd, fileName)
rdd <- textFile(sc, fileName)
- expect_equal(collect(rdd), lapply(l, function(x) {toString(x)}))
+ expect_equal(collectRDD(rdd), lapply(l, function(x) {toString(x)}))
unlink(fileName)
})
@@ -117,7 +117,7 @@ test_that("textFile() and saveAsTextFile() word count works as expected", {
saveAsTextFile(counts, fileName2)
rdd <- textFile(sc, fileName2)
- output <- collect(rdd)
+ output <- collectRDD(rdd)
expected <- list(list("awesome.", 1), list("Spark", 2),
list("pretty.", 1), list("is", 2))
expectedStr <- lapply(expected, function(x) { toString(x) })
@@ -134,7 +134,7 @@ test_that("textFile() on multiple paths", {
writeLines("Spark is awesome.", fileName2)
rdd <- textFile(sc, c(fileName1, fileName2))
- expect_equal(count(rdd), 2)
+ expect_equal(countRDD(rdd), 2)
unlink(fileName1)
unlink(fileName2)
@@ -147,16 +147,16 @@ test_that("Pipelined operations on RDDs created using textFile", {
rdd <- textFile(sc, fileName)
lengths <- lapply(rdd, function(x) { length(x) })
- expect_equal(collect(lengths), list(1, 1))
+ expect_equal(collectRDD(lengths), list(1, 1))
lengthsPipelined <- lapply(lengths, function(x) { x + 10 })
- expect_equal(collect(lengthsPipelined), list(11, 11))
+ expect_equal(collectRDD(lengthsPipelined), list(11, 11))
lengths30 <- lapply(lengthsPipelined, function(x) { x + 20 })
- expect_equal(collect(lengths30), list(31, 31))
+ expect_equal(collectRDD(lengths30), list(31, 31))
lengths20 <- lapply(lengths, function(x) { x + 20 })
- expect_equal(collect(lengths20), list(21, 21))
+ expect_equal(collectRDD(lengths20), list(21, 21))
unlink(fileName)
})
diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R
index 21a119a06b937..42783fde5f25b 100644
--- a/R/pkg/inst/tests/testthat/test_utils.R
+++ b/R/pkg/inst/tests/testthat/test_utils.R
@@ -24,7 +24,7 @@ sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext",
test_that("convertJListToRList() gives back (deserializes) the original JLists
of strings and integers", {
# It's hard to manually create a Java List using rJava, since it does not
- # support generics well. Instead, we rely on collect() returning a
+ # support generics well. Instead, we rely on collectRDD() returning a
# JList.
nums <- as.list(1:10)
rdd <- parallelize(sc, nums, 1L)
@@ -48,7 +48,7 @@ test_that("serializeToBytes on RDD", {
text.rdd <- textFile(sc, fileName)
expect_equal(getSerializedMode(text.rdd), "string")
ser.rdd <- serializeToBytes(text.rdd)
- expect_equal(collect(ser.rdd), as.list(mockFile))
+ expect_equal(collectRDD(ser.rdd), as.list(mockFile))
expect_equal(getSerializedMode(ser.rdd), "byte")
unlink(fileName)
@@ -128,7 +128,7 @@ test_that("cleanClosure on R functions", {
env <- environment(newF)
expect_equal(ls(env), "t")
expect_equal(get("t", envir = env, inherits = FALSE), t)
- actual <- collect(lapply(rdd, f))
+ actual <- collectRDD(lapply(rdd, f))
expected <- as.list(c(rep(FALSE, 4), rep(TRUE, 6)))
expect_equal(actual, expected)
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala
index 37bfcdfdf4777..097728c821570 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala
@@ -22,6 +22,4 @@ private[master] object ApplicationState extends Enumeration {
type ApplicationState = Value
val WAITING, RUNNING, FINISHED, FAILED, KILLED, UNKNOWN = Value
-
- val MAX_NUM_RETRY = 10
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index fded8475a0916..dfffc47703ab4 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -58,6 +58,7 @@ private[deploy] class Master(
private val RETAINED_DRIVERS = conf.getInt("spark.deploy.retainedDrivers", 200)
private val REAPER_ITERATIONS = conf.getInt("spark.dead.worker.persistence", 15)
private val RECOVERY_MODE = conf.get("spark.deploy.recoveryMode", "NONE")
+ private val MAX_EXECUTOR_RETRIES = conf.getInt("spark.deploy.maxExecutorRetries", 10)
val workers = new HashSet[WorkerInfo]
val idToApp = new HashMap[String, ApplicationInfo]
@@ -265,7 +266,11 @@ private[deploy] class Master(
val normalExit = exitStatus == Some(0)
// Only retry certain number of times so we don't go into an infinite loop.
- if (!normalExit && appInfo.incrementRetryCount() >= ApplicationState.MAX_NUM_RETRY) {
+ // Important note: this code path is not exercised by tests, so be very careful when
+ // changing this `if` condition.
+ if (!normalExit
+ && appInfo.incrementRetryCount() >= MAX_EXECUTOR_RETRIES
+ && MAX_EXECUTOR_RETRIES >= 0) { // < 0 disables this application-killing path
val execs = appInfo.executors.values
if (!execs.exists(_.state == ExecutorState.RUNNING)) {
logError(s"Application ${appInfo.desc.name} with ID ${appInfo.id} failed " +
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index 5bb505bf09f17..dd149a919fe55 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -225,6 +225,15 @@ class TaskMetrics private[spark] () extends Serializable {
}
private[spark] def accumulators(): Seq[AccumulatorV2[_, _]] = internalAccums ++ externalAccums
+
+ /**
+ * Looks for a registered accumulator by accumulator name.
+ */
+ private[spark] def lookForAccumulatorByName(name: String): Option[AccumulatorV2[_, _]] = {
+ accumulators.find { acc =>
+ acc.name.isDefined && acc.name.get == name
+ }
+ }
}
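The helper is `private[spark]`, so user code interacts with it only through named accumulators. Below is a minimal sketch, assuming a local session; the accumulator name mirrors the Parquet reader change later in this patch but is otherwise arbitrary, and the session setup is illustrative rather than part of the change.

```scala
import org.apache.spark.sql.SparkSession

// Illustrative driver-side setup: register a named LongAccumulator.
val spark = SparkSession.builder().master("local[*]").appName("named-accumulators").getOrCreate()
val numRowGroups = spark.sparkContext.longAccumulator("numRowGroups")

// Inside a running task, Spark-internal code can now resolve the same accumulator by name,
// roughly: TaskContext.get().taskMetrics().lookForAccumulatorByName("numRowGroups")
// and add to it when it is defined. The driver reads the accumulated value as usual:
spark.sparkContext.parallelize(1 to 4, 2).foreach(_ => numRowGroups.add(1))
assert(numRowGroups.value == 4)
```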
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala
index 99e6d39583747..2dcd67c7b89fa 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala
@@ -552,7 +552,12 @@ private[spark] class MesosCoarseGrainedSchedulerBackend(
taskId: String,
reason: String): Unit = {
stateLock.synchronized {
- removeExecutor(taskId, SlaveLost(reason))
+ // Do not call removeExecutor() after this scheduler backend has stopped: removeExecutor()
+ // internally sends a message to the driver endpoint, which is no longer available at that
+ // point, and the call would otherwise throw an exception.
+ if (!stopCalled) {
+ removeExecutor(taskId, SlaveLost(reason))
+ }
slaves(slaveId).taskIDs.remove(taskId)
}
}
diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
index a9167ce6edf90..d130a37db5b5d 100644
--- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
+++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
@@ -23,6 +23,8 @@ import java.util.ArrayList
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicLong
+import scala.collection.JavaConverters._
+
import org.apache.spark.{InternalAccumulator, SparkContext, TaskContext}
import org.apache.spark.scheduler.AccumulableInfo
@@ -257,6 +259,16 @@ private[spark] object AccumulatorContext {
originals.clear()
}
+ /**
+ * Looks for a registered accumulator by accumulator name.
+ */
+ private[spark] def lookForAccumulatorByName(name: String): Option[AccumulatorV2[_, _]] = {
+ originals.values().asScala.find { ref =>
+ val acc = ref.get
+ acc != null && acc.name.isDefined && acc.name.get == name
+ }.map(_.get)
+ }
+
// Identifier for distinguishing SQL metrics from other accumulators
private[spark] val SQL_ACCUM_IDENTIFIER = "sql"
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala
index 7f21d4c623afc..f6ec167a187de 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala
@@ -21,6 +21,7 @@ import java.util.Collections
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
+import scala.reflect.ClassTag
import org.apache.mesos.{Protos, Scheduler, SchedulerDriver}
import org.apache.mesos.Protos._
@@ -34,6 +35,7 @@ import org.scalatest.BeforeAndAfter
import org.apache.spark.{LocalSparkContext, SecurityManager, SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient
import org.apache.spark.rpc.RpcEndpointRef
+import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RemoveExecutor
import org.apache.spark.scheduler.TaskSchedulerImpl
class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite
@@ -47,6 +49,7 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite
private var backend: MesosCoarseGrainedSchedulerBackend = _
private var externalShuffleClient: MesosExternalShuffleClient = _
private var driverEndpoint: RpcEndpointRef = _
+ @volatile private var stopCalled = false
test("mesos supports killing and limiting executors") {
setBackend()
@@ -252,6 +255,32 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite
backend.start()
}
+ test("Do not call removeExecutor() after backend is stopped") {
+ setBackend()
+
+ // launches a task on a valid offer
+ val offers = List((backend.executorMemory(sc), 1))
+ offerResources(offers)
+ verifyTaskLaunched("o1")
+
+ // launches a thread simulating a status update that arrives after stop()
+ val statusUpdateThread = new Thread {
+ override def run(): Unit = {
+ while (!stopCalled) {
+ Thread.sleep(100)
+ }
+
+ val status = createTaskStatus("0", "s1", TaskState.TASK_FINISHED)
+ backend.statusUpdate(driver, status)
+ }
+ }.start
+
+ backend.stop()
+ // Any method of the backend involving sending messages to the driver endpoint should not
+ // be called after the backend is stopped.
+ verify(driverEndpoint, never()).askWithRetry(isA(classOf[RemoveExecutor]))(any[ClassTag[_]])
+ }
+
private def verifyDeclinedOffer(driver: SchedulerDriver,
offerId: OfferID,
filter: Boolean = false): Unit = {
@@ -350,6 +379,10 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite
mesosDriver = newDriver
}
+ override def stopExecutors(): Unit = {
+ stopCalled = true
+ }
+
markRegistered()
}
backend.start()
diff --git a/docs/configuration.md b/docs/configuration.md
index 8facd0ecf367a..500a6dad113da 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -1449,8 +1449,10 @@ Apart from these, the following properties are also available, and may be useful
the properties must be overwritten in the protocol-specific namespace.
Use <code>spark.ssl.YYY.XXX</code> settings to overwrite the global configuration for
- particular protocol denoted by <code>YYY</code>. Currently <code>YYY</code> can be
- only <code>fs</code> for file server.
+ particular protocol denoted by <code>YYY</code>. Example values for <code>YYY</code>
+ include <code>fs</code>, <code>ui</code>, <code>standalone</code>, and
+ <code>historyServer</code>. See <a href="security.html#ssl-configuration">SSL
+ Configuration</a> for details on hierarchical SSL configuration for services.
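As a hedged illustration of the layering described above (the property names, paths, and the `ui` namespace override are chosen for illustration, not a verified list of supported keys):

```scala
import org.apache.spark.SparkConf

// Global spark.ssl.* keys apply to every supported service; a spark.ssl.YYY.* key overrides
// the corresponding global key for that one service only.
val conf = new SparkConf()
  .set("spark.ssl.enabled", "true")
  .set("spark.ssl.keyStore", "/path/to/global-keystore.jks")   // global default keystore
  .set("spark.ssl.ui.keyStore", "/path/to/ui-keystore.jks")    // web UI uses its own keystore
```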
diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md
index bf4b968eb8b78..07b38d9cc9a8f 100644
--- a/docs/graphx-programming-guide.md
+++ b/docs/graphx-programming-guide.md
@@ -24,7 +24,6 @@ description: GraphX graph processing library guide for Spark SPARK_VERSION_SHORT
[Graph.outerJoinVertices]: api/scala/index.html#org.apache.spark.graphx.Graph@outerJoinVertices[U,VD2](RDD[(VertexId,U)])((VertexId,VD,Option[U])⇒VD2)(ClassTag[U],ClassTag[VD2]):Graph[VD2,ED]
[Graph.aggregateMessages]: api/scala/index.html#org.apache.spark.graphx.Graph@aggregateMessages[A]((EdgeContext[VD,ED,A])⇒Unit,(A,A)⇒A,TripletFields)(ClassTag[A]):VertexRDD[A]
[EdgeContext]: api/scala/index.html#org.apache.spark.graphx.EdgeContext
-[Graph.mapReduceTriplets]: api/scala/index.html#org.apache.spark.graphx.Graph@mapReduceTriplets[A](mapFunc:org.apache.spark.graphx.EdgeTriplet[VD,ED]=>Iterator[(org.apache.spark.graphx.VertexId,A)],reduceFunc:(A,A)=>A,activeSetOpt:Option[(org.apache.spark.graphx.VertexRDD[_],org.apache.spark.graphx.EdgeDirection)])(implicitevidence$10:scala.reflect.ClassTag[A]):org.apache.spark.graphx.VertexRDD[A]
[GraphOps.collectNeighborIds]: api/scala/index.html#org.apache.spark.graphx.GraphOps@collectNeighborIds(EdgeDirection):VertexRDD[Array[VertexId]]
[GraphOps.collectNeighbors]: api/scala/index.html#org.apache.spark.graphx.GraphOps@collectNeighbors(EdgeDirection):VertexRDD[Array[(VertexId,VD)]]
[RDD Persistence]: programming-guide.html#rdd-persistence
@@ -596,7 +595,7 @@ compute the average age of the more senior followers of each user.
### Map Reduce Triplets Transition Guide (Legacy)
In earlier versions of GraphX neighborhood aggregation was accomplished using the
-[`mapReduceTriplets`][Graph.mapReduceTriplets] operator:
+`mapReduceTriplets` operator:
{% highlight scala %}
class Graph[VD, ED] {
@@ -607,7 +606,7 @@ class Graph[VD, ED] {
}
{% endhighlight %}
-The [`mapReduceTriplets`][Graph.mapReduceTriplets] operator takes a user defined map function which
+The `mapReduceTriplets` operator takes a user defined map function which
is applied to each triplet and can yield *messages* which are aggregated using the user defined
`reduce` function.
However, we found the use of the returned iterator to be expensive, and it inhibited our ability to
diff --git a/docs/ml-advanced.md b/docs/ml-advanced.md
index f5804fdeee5aa..12a03d3c91984 100644
--- a/docs/ml-advanced.md
+++ b/docs/ml-advanced.md
@@ -49,7 +49,7 @@ MLlib L-BFGS solver calls the corresponding implementation in [breeze](https://g
## Normal equation solver for weighted least squares
-MLlib implements normal equation solver for [weighted least squares](https://en.wikipedia.org/wiki/Least_squares#Weighted_least_squares) by [WeightedLeastSquares](https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala).
+MLlib implements a normal equation solver for [weighted least squares](https://en.wikipedia.org/wiki/Least_squares#Weighted_least_squares) by [WeightedLeastSquares]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala).
Given $n$ weighted observations $(w_i, a_i, b_i)$:
@@ -73,7 +73,7 @@ In order to make the normal equation approach efficient, WeightedLeastSquares re
## Iteratively reweighted least squares (IRLS)
-MLlib implements [iteratively reweighted least squares (IRLS)](https://en.wikipedia.org/wiki/Iteratively_reweighted_least_squares) by [IterativelyReweightedLeastSquares](https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala).
+MLlib implements [iteratively reweighted least squares (IRLS)](https://en.wikipedia.org/wiki/Iteratively_reweighted_least_squares) by [IterativelyReweightedLeastSquares]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala).
It can be used to find the maximum likelihood estimates of a generalized linear model (GLM), to find M-estimators in robust regression, and to solve other optimization problems.
Refer to [Iteratively Reweighted Least Squares for Maximum Likelihood Estimation, and some Robust and Resistant Alternatives](http://www.jstor.org/stable/2345503) for more information.
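For context on the two solvers referenced above, the underlying problem can be stated as the generic weighted least squares objective (written here for orientation, not quoted from the linked pages); IRLS solves a GLM by repeatedly minimizing such an objective with the weights re-estimated at each iteration:

```latex
\min_{x \in \mathbb{R}^d} \;\; \frac{1}{2} \sum_{i=1}^{n} w_i \left( a_i^{\top} x - b_i \right)^2
```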
diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md
index c864c9030835e..5ae63fe4e6e07 100644
--- a/docs/spark-standalone.md
+++ b/docs/spark-standalone.md
@@ -195,6 +195,21 @@ SPARK_MASTER_OPTS supports the following system properties:
the whole cluster by default.
+ <tr>
+ <td><code>spark.deploy.maxExecutorRetries</code></td>
+ <td>10</td>
+ <td>
+ Limit on the maximum number of back-to-back executor failures that can occur before the
+ standalone cluster manager removes a faulty application. An application will never be removed
+ if it has any running executors. If an application experiences more than
+ <code>spark.deploy.maxExecutorRetries</code> failures in a row, no executors
+ successfully start running in between those failures, and the application has no running
+ executors, then the standalone cluster manager will remove the application and mark it as failed.
+ To disable this automatic removal, set <code>spark.deploy.maxExecutorRetries</code> to
+ <code>-1</code>.
+ </td>
+ </tr>
<td><code>spark.worker.timeout</code></td>
<td>60</td>
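A minimal sketch of setting the new property from an application (the application name and chosen value are illustrative; only the key `spark.deploy.maxExecutorRetries` and its default of 10 come from this patch):

```scala
import org.apache.spark.SparkConf

// Tolerate up to 20 consecutive executor failures before the standalone master removes the
// application; use "-1" to disable the automatic removal path entirely.
val conf = new SparkConf()
  .setAppName("flaky-but-tolerated")   // illustrative name
  .set("spark.deploy.maxExecutorRetries", "20")
```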
diff --git a/docs/streaming-custom-receivers.md b/docs/streaming-custom-receivers.md
index 479140f519103..f52bf348fcc99 100644
--- a/docs/streaming-custom-receivers.md
+++ b/docs/streaming-custom-receivers.md
@@ -181,7 +181,7 @@ val words = lines.flatMap(_.split(" "))
...
{% endhighlight %}
-The full source code is in the example [CustomReceiver.scala](https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala).
+The full source code is in the example [CustomReceiver.scala]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala).
@@ -193,7 +193,7 @@ JavaDStream<String> words = lines.flatMap(new FlatMapFunction<String, String>()
...
{% endhighlight %}
-The full source code is in the example [JavaCustomReceiver.java](https://github.com/apache/spark/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java).
+The full source code is in the example [JavaCustomReceiver.java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java).
diff --git a/docs/streaming-flume-integration.md b/docs/streaming-flume-integration.md
index 8eeeee75dbf40..767e1f9402e01 100644
--- a/docs/streaming-flume-integration.md
+++ b/docs/streaming-flume-integration.md
@@ -63,7 +63,7 @@ configuring Flume agents.
By default, the Python API will decode Flume event body as UTF8 encoded strings. You can specify your custom decoding function to decode the body byte arrays in Flume events to any arbitrary data type.
See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.flume.FlumeUtils)
- and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/flume_wordcount.py).
+ and the [example]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/streaming/flume_wordcount.py).
diff --git a/docs/streaming-kafka-0-8-integration.md b/docs/streaming-kafka-0-8-integration.md
index da4a845fe2d41..f8f7b95cf7458 100644
--- a/docs/streaming-kafka-0-8-integration.md
+++ b/docs/streaming-kafka-0-8-integration.md
@@ -29,7 +29,7 @@ Next, we discuss how to use this approach in your streaming application.
[ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume])
You can also specify the key and value classes and their corresponding decoder classes using variations of `createStream`. See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$)
- and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala).
+ and the [example]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala).
import org.apache.spark.streaming.kafka.*;
@@ -39,7 +39,7 @@ Next, we discuss how to use this approach in your streaming application.
[ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume]);
You can also specify the key and value classes and their corresponding decoder classes using variations of `createStream`. See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html)
- and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java).
+ and the [example]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java).
@@ -49,7 +49,7 @@ Next, we discuss how to use this approach in your streaming application.
[ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume])
By default, the Python API will decode Kafka data as UTF8 encoded strings. You can specify your custom decoding function to decode the byte arrays in Kafka records to any arbitrary data type. See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.kafka.KafkaUtils)
- and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/kafka_wordcount.py).
+ and the [example]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/streaming/kafka_wordcount.py).
@@ -106,7 +106,7 @@ Next, we discuss how to use this approach in your streaming application.
You can also pass a `messageHandler` to `createDirectStream` to access `MessageAndMetadata` that contains metadata about the current message and transform it to any desired type.
See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$)
- and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala).
+ and the [example]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala).
import org.apache.spark.streaming.kafka.*;
@@ -118,7 +118,7 @@ Next, we discuss how to use this approach in your streaming application.
You can also pass a `messageHandler` to `createDirectStream` to access `MessageAndMetadata` that contains metadata about the current message and transform it to any desired type.
See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html)
- and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java).
+ and the [example]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java).
@@ -127,7 +127,7 @@ Next, we discuss how to use this approach in your streaming application.
You can also pass a `messageHandler` to `createDirectStream` to access `KafkaMessageAndMetadata` that contains metadata about the current message and transform it to any desired type.
By default, the Python API will decode Kafka data as UTF8 encoded strings. You can specify your custom decoding function to decode the byte arrays in Kafka records to any arbitrary data type. See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.kafka.KafkaUtils)
- and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/direct_kafka_wordcount.py).
+ and the [example]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/streaming/direct_kafka_wordcount.py).
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md
index 3d40b2c3136eb..14e17443e362c 100644
--- a/docs/streaming-programming-guide.md
+++ b/docs/streaming-programming-guide.md
@@ -126,7 +126,7 @@ ssc.awaitTermination() // Wait for the computation to terminate
{% endhighlight %}
The complete code can be found in the Spark Streaming example
-[NetworkWordCount]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala).
+[NetworkWordCount]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala).
@@ -216,7 +216,7 @@ jssc.awaitTermination(); // Wait for the computation to terminate
{% endhighlight %}
The complete code can be found in the Spark Streaming example
-[JavaNetworkWordCount]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java).
+[JavaNetworkWordCount]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java).
@@ -277,7 +277,7 @@ ssc.awaitTermination() # Wait for the computation to terminate
{% endhighlight %}
The complete code can be found in the Spark Streaming example
-[NetworkWordCount]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/network_wordcount.py).
+[NetworkWordCount]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/streaming/network_wordcount.py).
@@ -854,7 +854,7 @@ JavaPairDStream<String, Integer> runningCounts = pairs.updateStateByKey(updateFu
The update function will be called for each word, with `newValues` having a sequence of 1's (from
the `(word, 1)` pairs) and the `runningCount` having the previous count. For the complete
Java code, take a look at the example
-[JavaStatefulNetworkWordCount.java]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming
+[JavaStatefulNetworkWordCount.java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/streaming
/JavaStatefulNetworkWordCount.java).
@@ -877,7 +877,7 @@ runningCounts = pairs.updateStateByKey(updateFunction)
The update function will be called for each word, with `newValues` having a sequence of 1's (from
the `(word, 1)` pairs) and the `runningCount` having the previous count. For the complete
Python code, take a look at the example
-[stateful_network_wordcount.py]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/stateful_network_wordcount.py).
+[stateful_network_wordcount.py]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/streaming/stateful_network_wordcount.py).
@@ -1428,7 +1428,7 @@ wordCounts.foreachRDD { (rdd: RDD[(String, Int)], time: Time) =>
{% endhighlight %}
-See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala).
+See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala).
{% highlight java %}
@@ -1491,7 +1491,7 @@ wordCounts.foreachRDD(new Function2<JavaPairRDD<String, Integer>, Time, Void>()
{% endhighlight %}
-See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java).
+See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java).
{% highlight python %}
@@ -1526,7 +1526,7 @@ wordCounts.foreachRDD(echo)
{% endhighlight %}
-See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/recoverable_network_wordcount.py).
+See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/streaming/recoverable_network_wordcount.py).
@@ -1564,7 +1564,7 @@ words.foreachRDD { rdd =>
{% endhighlight %}
-See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala).
+See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala).
{% highlight java %}
@@ -1619,7 +1619,7 @@ words.foreachRDD(
);
{% endhighlight %}
-See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java).
+See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java).
{% highlight python %}
@@ -1661,7 +1661,7 @@ def process(time, rdd):
words.foreachRDD(process)
{% endhighlight %}
-See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/sql_network_wordcount.py).
+See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/streaming/sql_network_wordcount.py).
diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md
index 8c14c3d220a23..811e8c408cb45 100644
--- a/docs/structured-streaming-programming-guide.md
+++ b/docs/structured-streaming-programming-guide.md
@@ -14,9 +14,9 @@ Structured Streaming is a scalable and fault-tolerant stream processing engine b
# Quick Example
Let’s say you want to maintain a running word count of text data received from a data server listening on a TCP socket. Let’s see how you can express this using Structured Streaming. You can see the full code in
-[Scala]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCount.scala)/
-[Java]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCount.java)/
-[Python]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/sql/streaming/structured_network_wordcount.py). And if you
+[Scala]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCount.scala)/
+[Java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCount.java)/
+[Python]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/sql/streaming/structured_network_wordcount.py). And if you
[download Spark](http://spark.apache.org/downloads.html), you can directly run the example. In any case, let’s walk through the example step-by-step and understand how it works. First, we have to import the necessary classes and create a local SparkSession, the starting point of all functionalities related to Spark.
@@ -618,9 +618,9 @@ The result tables would look something like the following.
[figure: windowed grouped-aggregation result tables]
Since this windowing is similar to grouping, in code, you can use `groupBy()` and `window()` operations to express windowed aggregations. You can see the full code for the below examples in
-[Scala]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCountWindowed.scala)/
-[Java]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCountWindowed.java)/
-[Python]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py).
+[Scala]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCountWindowed.scala)/
+[Java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCountWindowed.java)/
+[Python]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py).
diff --git a/docs/tuning.md b/docs/tuning.md
index 1ed14091c0546..976f2eb8a7b23 100644
--- a/docs/tuning.md
+++ b/docs/tuning.md
@@ -115,7 +115,7 @@ Although there are two relevant configurations, the typical user should not need
as the default values are applicable to most workloads:
* `spark.memory.fraction` expresses the size of `M` as a fraction of the (JVM heap space - 300MB)
-(default 0.6). The rest of the space (25%) is reserved for user data structures, internal
+(default 0.6). The rest of the space (40%) is reserved for user data structures, internal
metadata in Spark, and safeguarding against OOM errors in the case of sparse and unusually
large records.
* `spark.memory.storageFraction` expresses the size of `R` as a fraction of `M` (default 0.5).
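As a quick sanity check of the corrected 40% figure, here is a rough calculation assuming a 4 GiB executor heap (the heap size is an assumption for illustration):

```scala
// With the default spark.memory.fraction of 0.6, the unified region M gets 60% of
// (heap - 300MB); the remaining 40% of that same base, plus the fixed 300MB reserve, is left
// for user data structures, internal metadata, and protection against OOM on unusual records.
val heapMb = 4096.0
val reservedMb = 300.0
val memoryFraction = 0.6
val unifiedMb = memoryFraction * (heapMb - reservedMb)         // ~2278 MB available as M
val userSpaceMb = (1 - memoryFraction) * (heapMb - reservedMb) // ~1518 MB, i.e. the 40%
println(f"M = $unifiedMb%.0f MB, user space = $userSpaceMb%.0f MB")
```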
diff --git a/external/kafka-0-10/pom.xml b/external/kafka-0-10/pom.xml
index 26965612cc0ab..f9776fc63686e 100644
--- a/external/kafka-0-10/pom.xml
+++ b/external/kafka-0-10/pom.xml
@@ -51,7 +51,7 @@
<groupId>org.apache.kafka</groupId>
<artifactId>kafka_${scala.binary.version}</artifactId>
- <version>0.10.0.0</version>
+ <version>0.10.0.1</version>
<groupId>com.sun.jmx</groupId>
diff --git a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
index edaafb912c5c5..b17e198077949 100644
--- a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
+++ b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
@@ -18,7 +18,7 @@
package org.apache.spark.streaming.kafka
import java.io.OutputStream
-import java.lang.{Integer => JInt, Long => JLong}
+import java.lang.{Integer => JInt, Long => JLong, Number => JNumber}
import java.nio.charset.StandardCharsets
import java.util.{List => JList, Map => JMap, Set => JSet}
@@ -682,7 +682,7 @@ private[kafka] class KafkaUtilsPythonHelper {
jssc: JavaStreamingContext,
kafkaParams: JMap[String, String],
topics: JSet[String],
- fromOffsets: JMap[TopicAndPartition, JLong]): JavaDStream[(Array[Byte], Array[Byte])] = {
+ fromOffsets: JMap[TopicAndPartition, JNumber]): JavaDStream[(Array[Byte], Array[Byte])] = {
val messageHandler =
(mmd: MessageAndMetadata[Array[Byte], Array[Byte]]) => (mmd.key, mmd.message)
new JavaDStream(createDirectStream(jssc, kafkaParams, topics, fromOffsets, messageHandler))
@@ -692,7 +692,7 @@ private[kafka] class KafkaUtilsPythonHelper {
jssc: JavaStreamingContext,
kafkaParams: JMap[String, String],
topics: JSet[String],
- fromOffsets: JMap[TopicAndPartition, JLong]): JavaDStream[Array[Byte]] = {
+ fromOffsets: JMap[TopicAndPartition, JNumber]): JavaDStream[Array[Byte]] = {
val messageHandler = (mmd: MessageAndMetadata[Array[Byte], Array[Byte]]) =>
new PythonMessageAndMetadata(mmd.topic, mmd.partition, mmd.offset, mmd.key(), mmd.message())
val stream = createDirectStream(jssc, kafkaParams, topics, fromOffsets, messageHandler).
@@ -704,7 +704,7 @@ private[kafka] class KafkaUtilsPythonHelper {
jssc: JavaStreamingContext,
kafkaParams: JMap[String, String],
topics: JSet[String],
- fromOffsets: JMap[TopicAndPartition, JLong],
+ fromOffsets: JMap[TopicAndPartition, JNumber],
messageHandler: MessageAndMetadata[Array[Byte], Array[Byte]] => V): DStream[V] = {
val currentFromOffsets = if (!fromOffsets.isEmpty) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
index 72fb35bd79ad7..6e872c1f2cada 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
@@ -19,6 +19,8 @@ package org.apache.spark.ml.feature
import scala.collection.mutable
+import org.apache.commons.math3.util.CombinatoricsUtils
+
import org.apache.spark.annotation.Since
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.linalg._
@@ -84,12 +86,12 @@ class PolynomialExpansion @Since("1.4.0") (@Since("1.4.0") override val uid: Str
@Since("1.6.0")
object PolynomialExpansion extends DefaultParamsReadable[PolynomialExpansion] {
- private def choose(n: Int, k: Int): Int = {
- Range(n, n - k, -1).product / Range(k, 1, -1).product
+ private def getPolySize(numFeatures: Int, degree: Int): Int = {
+ val n = CombinatoricsUtils.binomialCoefficient(numFeatures + degree, degree)
+ require(n <= Integer.MAX_VALUE)
+ n.toInt
}
- private def getPolySize(numFeatures: Int, degree: Int): Int = choose(numFeatures + degree, degree)
-
private def expandDense(
values: Array[Double],
lastIdx: Int,
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
index 8e1f9ddb36cbe..9ecd321b128f6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
@@ -116,5 +116,29 @@ class PolynomialExpansionSuite
.setDegree(3)
testDefaultReadWrite(t)
}
+
+ test("SPARK-17027. Integer overflow in PolynomialExpansion.getPolySize") {
+ val data: Array[(Vector, Int, Int)] = Array(
+ (Vectors.dense(1.0, 2.0, 3.0, 4.0, 5.0), 3002, 4367),
+ (Vectors.sparse(5, Seq((0, 1.0), (4, 5.0))), 3002, 4367),
+ (Vectors.dense(1.0, 2.0, 3.0, 4.0, 5.0, 6.0), 8007, 12375)
+ )
+
+ val df = spark.createDataFrame(data)
+ .toDF("features", "expectedPoly10size", "expectedPoly11size")
+
+ val t = new PolynomialExpansion()
+ .setInputCol("features")
+ .setOutputCol("polyFeatures")
+
+ for (i <- Seq(10, 11)) {
+ val transformed = t.setDegree(i)
+ .transform(df)
+ .select(s"expectedPoly${i}size", "polyFeatures")
+ .rdd.map { case Row(expected: Int, v: Vector) => expected == v.size }
+
+ assert(transformed.collect.forall(identity))
+ }
+ }
}
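The SPARK-17027 test above checks the expanded sizes for degrees 10 and 11; the sketch below, with made-up inputs, shows the scale of the overflow the `CombinatoricsUtils`-based `getPolySize` now guards against:

```scala
import org.apache.commons.math3.util.CombinatoricsUtils

// The expanded vector size is C(numFeatures + degree, degree). The old Range(...).product
// computed this in Int arithmetic and could silently overflow; binomialCoefficient works in
// Long, and the new require(...) rejects anything above Integer.MAX_VALUE up front.
val numFeatures = 3000  // made-up input, large enough to overflow Int
val degree = 3
val size = CombinatoricsUtils.binomialCoefficient(numFeatures + degree, degree)
println(size)                       // roughly 4.5e9
println(size <= Integer.MAX_VALUE)  // false: getPolySize now fails fast instead of wrapping
```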
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 8a01805ec831b..4ea83e24bbc9a 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1440,11 +1440,15 @@ def split(str, pattern):
@ignore_unicode_prefix
@since(1.5)
def regexp_extract(str, pattern, idx):
- """Extract a specific(idx) group identified by a java regex, from the specified string column.
+ """Extract a specific group matched by a Java regex, from the specified string column.
+ If the regex did not match, or the specified group did not match, an empty string is returned.
>>> df = spark.createDataFrame([('100-200',)], ['str'])
>>> df.select(regexp_extract('str', '(\d+)-(\d+)', 1).alias('d')).collect()
[Row(d=u'100')]
+ >>> df = spark.createDataFrame([('foo',)], ['str'])
+ >>> df.select(regexp_extract('str', '(\d+)', 1).alias('d')).collect()
+ [Row(d=u'')]
>>> df = spark.createDataFrame([('aaaac',)], ['str'])
>>> df.select(regexp_extract('str', '(a+)(b)?(c)', 2).alias('d')).collect()
[Row(d=u'')]
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 4020bb3fa45b0..64de33e8ec0a8 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -401,8 +401,9 @@ def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPar
:param numPartitions: the number of partitions
:param predicates: a list of expressions suitable for inclusion in WHERE clauses;
each one defines one partition of the :class:`DataFrame`
- :param properties: a dictionary of JDBC database connection arguments; normally,
- at least a "user" and "password" property should be included
+ :param properties: a dictionary of JDBC database connection arguments. Normally it should
+ contain at least the "user" and "password" keys with their corresponding values,
+ for example { 'user' : 'SYSTEM', 'password' : 'mypassword' }.
:return: a DataFrame
"""
if properties is None:
@@ -716,9 +717,9 @@ def jdbc(self, url, table, mode=None, properties=None):
* ``overwrite``: Overwrite existing data.
* ``ignore``: Silently ignore this operation if data already exists.
* ``error`` (default case): Throw an exception if data already exists.
- :param properties: JDBC database connection arguments, a list of
- arbitrary string tag/value. Normally at least a
- "user" and "password" property should be included.
+ :param properties: a dictionary of JDBC database connection arguments. Normally it should
+ contain at least the "user" and "password" keys with their corresponding values,
+ for example { 'user' : 'SYSTEM', 'password' : 'mypassword' }.
"""
if properties is None:
properties = dict()
diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py
index 2c1a667fc80c4..bf27d8047a753 100644
--- a/python/pyspark/streaming/kafka.py
+++ b/python/pyspark/streaming/kafka.py
@@ -287,6 +287,9 @@ def __eq__(self, other):
def __ne__(self, other):
return not self.__eq__(other)
+ def __hash__(self):
+ return (self._topic, self._partition).__hash__()
+
class Broker(object):
"""
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index 360ba1e7167cb..5ac007cd598b9 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -41,6 +41,9 @@
else:
import unittest
+if sys.version >= "3":
+ long = int
+
from pyspark.context import SparkConf, SparkContext, RDD
from pyspark.storagelevel import StorageLevel
from pyspark.streaming.context import StreamingContext
@@ -1058,7 +1061,6 @@ def test_kafka_direct_stream(self):
stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams)
self._validateStreamResult(sendData, stream)
- @unittest.skipIf(sys.version >= "3", "long type not support")
def test_kafka_direct_stream_from_offset(self):
"""Test the Python direct Kafka stream API with start offset specified."""
topic = self._randomTopic()
@@ -1072,7 +1074,6 @@ def test_kafka_direct_stream_from_offset(self):
stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams, fromOffsets)
self._validateStreamResult(sendData, stream)
- @unittest.skipIf(sys.version >= "3", "long type not support")
def test_kafka_rdd(self):
"""Test the Python direct Kafka RDD API."""
topic = self._randomTopic()
@@ -1085,7 +1086,6 @@ def test_kafka_rdd(self):
rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges)
self._validateRddResult(sendData, rdd)
- @unittest.skipIf(sys.version >= "3", "long type not support")
def test_kafka_rdd_with_leaders(self):
"""Test the Python direct Kafka RDD API with leaders."""
topic = self._randomTopic()
@@ -1100,7 +1100,6 @@ def test_kafka_rdd_with_leaders(self):
rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges, leaders)
self._validateRddResult(sendData, rdd)
- @unittest.skipIf(sys.version >= "3", "long type not support")
def test_kafka_rdd_get_offsetRanges(self):
"""Test Python direct Kafka RDD get OffsetRanges."""
topic = self._randomTopic()
@@ -1113,7 +1112,6 @@ def test_kafka_rdd_get_offsetRanges(self):
rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges)
self.assertEqual(offsetRanges, rdd.offsetRanges())
- @unittest.skipIf(sys.version >= "3", "long type not support")
def test_kafka_direct_stream_foreach_get_offsetRanges(self):
"""Test the Python direct Kafka stream foreachRDD get offsetRanges."""
topic = self._randomTopic()
@@ -1138,7 +1136,6 @@ def getOffsetRanges(_, rdd):
self.assertEqual(offsetRanges, [OffsetRange(topic, 0, long(0), long(6))])
- @unittest.skipIf(sys.version >= "3", "long type not support")
def test_kafka_direct_stream_transform_get_offsetRanges(self):
"""Test the Python direct Kafka stream transform get offsetRanges."""
topic = self._randomTopic()
@@ -1176,7 +1173,6 @@ def test_topic_and_partition_equality(self):
self.assertNotEqual(topic_and_partition_a, topic_and_partition_c)
self.assertNotEqual(topic_and_partition_a, topic_and_partition_d)
- @unittest.skipIf(sys.version >= "3", "long type not support")
def test_kafka_direct_stream_transform_with_checkpoint(self):
"""Test the Python direct Kafka stream transform with checkpoint correctly recovered."""
topic = self._randomTopic()
@@ -1225,7 +1221,6 @@ def setup():
finally:
shutil.rmtree(tmpdir)
- @unittest.skipIf(sys.version >= "3", "long type not support")
def test_kafka_rdd_message_handler(self):
"""Test Python direct Kafka RDD MessageHandler."""
topic = self._randomTopic()
@@ -1242,7 +1237,6 @@ def getKeyAndDoubleMessage(m):
messageHandler=getKeyAndDoubleMessage)
self._validateRddResult({"aa": 1, "bb": 1, "cc": 2}, rdd)
- @unittest.skipIf(sys.version >= "3", "long type not support")
def test_kafka_direct_stream_message_handler(self):
"""Test the Python direct Kafka stream MessageHandler."""
topic = self._randomTopic()
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
index de98a871b3358..aca728234ad99 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
@@ -84,6 +84,7 @@ statement
| ALTER VIEW tableIdentifier
DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* #dropTablePartitions
| ALTER TABLE tableIdentifier partitionSpec? SET locationSpec #setTableLocation
+ | ALTER TABLE tableIdentifier RECOVER PARTITIONS #recoverPartitions
| DROP TABLE (IF EXISTS)? tableIdentifier PURGE? #dropTable
| DROP VIEW (IF EXISTS)? tableIdentifier #dropTable
| CREATE (OR REPLACE)? TEMPORARY? VIEW (IF NOT EXISTS)? tableIdentifier
@@ -121,6 +122,7 @@ statement
| LOAD DATA LOCAL? INPATH path=STRING OVERWRITE? INTO TABLE
tableIdentifier partitionSpec? #loadData
| TRUNCATE TABLE tableIdentifier partitionSpec? #truncateTable
+ | MSCK REPAIR TABLE tableIdentifier #repairTable
| op=(ADD | LIST) identifier .*? #manageResource
| SET ROLE .*? #failNativeCommand
| SET .*? #setConfiguration
@@ -154,7 +156,6 @@ unsupportedHiveNativeCommands
| kw1=UNLOCK kw2=DATABASE
| kw1=CREATE kw2=TEMPORARY kw3=MACRO
| kw1=DROP kw2=TEMPORARY kw3=MACRO
- | kw1=MSCK kw2=REPAIR kw3=TABLE
| kw1=ALTER kw2=TABLE tableIdentifier kw3=NOT kw4=CLUSTERED
| kw1=ALTER kw2=TABLE tableIdentifier kw3=CLUSTERED kw4=BY
| kw1=ALTER kw2=TABLE tableIdentifier kw3=NOT kw4=SORTED
@@ -617,13 +618,13 @@ quotedIdentifier
;
number
- : DECIMAL_VALUE #decimalLiteral
- | SCIENTIFIC_DECIMAL_VALUE #scientificDecimalLiteral
- | INTEGER_VALUE #integerLiteral
- | BIGINT_LITERAL #bigIntLiteral
- | SMALLINT_LITERAL #smallIntLiteral
- | TINYINT_LITERAL #tinyIntLiteral
- | DOUBLE_LITERAL #doubleLiteral
+ : MINUS? DECIMAL_VALUE #decimalLiteral
+ | MINUS? SCIENTIFIC_DECIMAL_VALUE #scientificDecimalLiteral
+ | MINUS? INTEGER_VALUE #integerLiteral
+ | MINUS? BIGINT_LITERAL #bigIntLiteral
+ | MINUS? SMALLINT_LITERAL #smallIntLiteral
+ | MINUS? TINYINT_LITERAL #tinyIntLiteral
+ | MINUS? DOUBLE_LITERAL #doubleLiteral
;
nonReserved
@@ -646,7 +647,7 @@ nonReserved
| CASCADE | RESTRICT | BUCKETS | CLUSTERED | SORTED | PURGE | INPUTFORMAT | OUTPUTFORMAT
| DBPROPERTIES | DFS | TRUNCATE | COMPUTE | LIST
| STATISTICS | ANALYZE | PARTITIONED | EXTERNAL | DEFINED | RECORDWRITER
- | REVOKE | GRANT | LOCK | UNLOCK | MSCK | REPAIR | EXPORT | IMPORT | LOAD | VALUES | COMMENT | ROLE
+ | REVOKE | GRANT | LOCK | UNLOCK | MSCK | REPAIR | RECOVER | EXPORT | IMPORT | LOAD | VALUES | COMMENT | ROLE
| ROLES | COMPACTIONS | PRINCIPALS | TRANSACTIONS | INDEX | INDEXES | LOCKS | OPTION | LOCAL | INPATH
| ASC | DESC | LIMIT | RENAME | SETS
| AT | NULLS | OVERWRITE | ALL | ALTER | AS | BETWEEN | BY | CREATE | DELETE
@@ -859,6 +860,7 @@ LOCK: 'LOCK';
UNLOCK: 'UNLOCK';
MSCK: 'MSCK';
REPAIR: 'REPAIR';
+RECOVER: 'RECOVER';
EXPORT: 'EXPORT';
IMPORT: 'IMPORT';
LOAD: 'LOAD';
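A hedged sketch of statements the new grammar alternatives accept, using an illustrative SparkSession and an invented partitioned table named `logs` (a Hive-backed catalog may be required for these DDL commands on real tables):

```scala
import org.apache.spark.sql.SparkSession

// Illustrative session setup only.
val spark = SparkSession.builder().master("local[*]").enableHiveSupport().getOrCreate()

// New #recoverPartitions and #repairTable rules: both ask Spark to scan the table's storage
// location and register the partitions it finds in the catalog.
spark.sql("ALTER TABLE logs RECOVER PARTITIONS")
spark.sql("MSCK REPAIR TABLE logs")

// With MINUS? now part of the number rule, a leading minus is parsed as part of the numeric
// literal itself rather than as a separate unary operator applied afterwards.
spark.sql("SELECT -2147483648 AS min_int")
```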
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index d83eef7a41629..e16850efbea5f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -463,6 +463,6 @@ trait Row extends Serializable {
* @throws NullPointerException when value is null.
*/
private def getAnyValAs[T <: AnyVal](i: Int): T =
- if (isNullAt(i)) throw new NullPointerException(s"Value at index $i in null")
+ if (isNullAt(i)) throw new NullPointerException(s"Value at index $i is null")
else getAs[T](i)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 660f523698e7f..57c3d9aece80c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -547,8 +547,7 @@ class Analyzer(
case a: Aggregate if containsStar(a.aggregateExpressions) =>
if (conf.groupByOrdinal && a.groupingExpressions.exists(IntegerIndex.unapply(_).nonEmpty)) {
failAnalysis(
- "Group by position: star is not allowed to use in the select list " +
- "when using ordinals in group by")
+ "Star (*) is not allowed in select list when GROUP BY ordinal position is used")
} else {
a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child))
}
@@ -723,9 +722,9 @@ class Analyzer(
if (index > 0 && index <= child.output.size) {
SortOrder(child.output(index - 1), direction)
} else {
- throw new UnresolvedException(s,
- s"Order/sort By position: $index does not exist " +
- s"The Select List is indexed from 1 to ${child.output.size}")
+ s.failAnalysis(
+ s"ORDER BY position $index is not in select list " +
+ s"(valid range is [1, ${child.output.size}])")
}
case o => o
}
@@ -737,17 +736,18 @@ class Analyzer(
if conf.groupByOrdinal && aggs.forall(_.resolved) &&
groups.exists(IntegerIndex.unapply(_).nonEmpty) =>
val newGroups = groups.map {
- case IntegerIndex(index) if index > 0 && index <= aggs.size =>
+ case ordinal @ IntegerIndex(index) if index > 0 && index <= aggs.size =>
aggs(index - 1) match {
case e if ResolveAggregateFunctions.containsAggregate(e) =>
- throw new UnresolvedException(a,
- s"Group by position: the '$index'th column in the select contains an " +
- s"aggregate function: ${e.sql}. Aggregate functions are not allowed in GROUP BY")
+ ordinal.failAnalysis(
+ s"GROUP BY position $index is an aggregate function, and " +
+ "aggregate functions are not allowed in GROUP BY")
case o => o
}
- case IntegerIndex(index) =>
- throw new UnresolvedException(a,
- s"Group by position: '$index' exceeds the size of the select list '${aggs.size}'.")
+ case ordinal @ IntegerIndex(index) =>
+ ordinal.failAnalysis(
+ s"GROUP BY position $index is not in select list " +
+ s"(valid range is [1, ${aggs.size}])")
case o => o
}
Aggregate(newGroups, aggs, child)
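A small sketch of a query that triggers the reworded message, assuming a local SparkSession (the view and column names are invented):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
spark.range(10).selectExpr("id % 2 AS k", "id AS v").createOrReplaceTempView("t")

// GROUP BY 3 points past the two-item select list, so analysis fails with
// "GROUP BY position 3 is not in select list (valid range is [1, 2])"
// instead of the previous UnresolvedException wording; out-of-range ORDER BY
// ordinals are reported in the same style.
spark.sql("SELECT k, sum(v) FROM t GROUP BY 3")
```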
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 021952e7166f9..21e96aaf53844 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -543,11 +543,14 @@ object TypeCoercion {
// Decimal and Double remain the same
case d: Divide if d.dataType == DoubleType => d
case d: Divide if d.dataType.isInstanceOf[DecimalType] => d
- case Divide(left, right) if isNumeric(left) && isNumeric(right) =>
+ case Divide(left, right) if isNumericOrNull(left) && isNumericOrNull(right) =>
Divide(Cast(left, DoubleType), Cast(right, DoubleType))
}
- private def isNumeric(ex: Expression): Boolean = ex.dataType.isInstanceOf[NumericType]
+ private def isNumericOrNull(ex: Expression): Boolean = {
+ // We need to handle null types in case a query contains null literals.
+ ex.dataType.isInstanceOf[NumericType] || ex.dataType == NullType
+ }
}
/**
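A tiny illustration of the case the new `isNumericOrNull` check covers, reusing a SparkSession `spark` as in the earlier sketches (assumed, not shown here):

```scala
// A null literal on either side of a division is now cast to double along with the numeric
// operand, so the query analyzes cleanly and evaluates to NULL instead of failing type checks.
spark.sql("SELECT 1 / NULL AS a, NULL / 2 AS b").show()
```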
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 7ff8795d4f05e..fa459aa2e5d72 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -57,7 +57,7 @@ case class UnaryMinus(child: Expression) extends UnaryExpression
}
}
- override def sql: String = s"(-${child.sql})"
+ override def sql: String = s"(- ${child.sql})"
}
@ExpressionDescription(
@@ -75,7 +75,7 @@ case class UnaryPositive(child: Expression)
protected override def nullSafeEval(input: Any): Any = input
- override def sql: String = s"(+${child.sql})"
+ override def sql: String = s"(+ ${child.sql})"
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 9621db1d38762..37ec1a63394cf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -232,27 +232,47 @@ case class NewInstance(
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val javaType = ctx.javaType(dataType)
- val argGen = arguments.map(_.genCode(ctx))
- val argString = argGen.map(_.value).mkString(", ")
+ val argIsNulls = ctx.freshName("argIsNulls")
+ ctx.addMutableState("boolean[]", argIsNulls,
+ s"$argIsNulls = new boolean[${arguments.size}];")
+ val argValues = arguments.zipWithIndex.map { case (e, i) =>
+ val argValue = ctx.freshName("argValue")
+ ctx.addMutableState(ctx.javaType(e.dataType), argValue, "")
+ argValue
+ }
+
+ val argCodes = arguments.zipWithIndex.map { case (e, i) =>
+ val expr = e.genCode(ctx)
+ expr.code + s"""
+ $argIsNulls[$i] = ${expr.isNull};
+ ${argValues(i)} = ${expr.value};
+ """
+ }
+ val argCode = ctx.splitExpressions(ctx.INPUT_ROW, argCodes)
val outer = outerPointer.map(func => Literal.fromObject(func()).genCode(ctx))
var isNull = ev.isNull
val setIsNull = if (propagateNull && arguments.nonEmpty) {
- s"final boolean $isNull = ${argGen.map(_.isNull).mkString(" || ")};"
+ s"""
+ boolean $isNull = false;
+ for (int idx = 0; idx < ${arguments.length}; idx++) {
+ if ($argIsNulls[idx]) { $isNull = true; break; }
+ }
+ """
} else {
isNull = "false"
""
}
val constructorCall = outer.map { gen =>
- s"""${gen.value}.new ${cls.getSimpleName}($argString)"""
+ s"""${gen.value}.new ${cls.getSimpleName}(${argValues.mkString(", ")})"""
}.getOrElse {
- s"new $className($argString)"
+ s"new $className(${argValues.mkString(", ")})"
}
val code = s"""
- ${argGen.map(_.code).mkString("\n")}
+ $argCode
${outer.map(_.code).getOrElse("")}
$setIsNull
final $javaType ${ev.value} = $isNull ? ${ctx.defaultValue(javaType)} : $constructorCall;
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index f24f8b78d476f..627f82994f8db 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -695,6 +695,19 @@ object FoldablePropagation extends Rule[LogicalPlan] {
case j @ Join(_, _, LeftOuter | RightOuter | FullOuter, _) =>
stop = true
j
+
+ // These three operators take attributes as constructor parameters, and those attributes
+ // can't be replaced by aliases.
+ case m: MapGroups =>
+ stop = true
+ m
+ case f: FlatMapGroupsInR =>
+ stop = true
+ f
+ case c: CoGroup =>
+ stop = true
+ c
+
case p: LogicalPlan if !stop => p.transformExpressions {
case a: AttributeReference if foldableMap.contains(a) =>
foldableMap(a)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 1a0e7ab32a6c3..aee8eb1f3877b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -132,7 +132,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
// Build the insert clauses.
val inserts = ctx.multiInsertQueryBody.asScala.map {
body =>
- assert(body.querySpecification.fromClause == null,
+ validate(body.querySpecification.fromClause == null,
"Multi-Insert queries cannot have a FROM clause in their individual SELECT statements",
body)
@@ -591,7 +591,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
// function takes X PERCENT as the input and the range of X is [0, 100], we need to
// adjust the fraction.
val eps = RandomSampler.roundingEpsilon
- assert(fraction >= 0.0 - eps && fraction <= 1.0 + eps,
+ validate(fraction >= 0.0 - eps && fraction <= 1.0 + eps,
s"Sampling fraction ($fraction) must be on interval [0, 1]",
ctx)
Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, query)(true)
@@ -659,7 +659,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
// Get the backing expressions.
val expressions = ctx.expression.asScala.map { eCtx =>
val e = expression(eCtx)
- assert(e.foldable, "All expressions in an inline table must be constants.", eCtx)
+ validate(e.foldable, "All expressions in an inline table must be constants.", eCtx)
e
}
@@ -681,7 +681,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
val baseAttributes = structType.toAttributes.map(_.withNullability(true))
val attributes = if (ctx.identifierList != null) {
val aliases = visitIdentifierList(ctx.identifierList)
- assert(aliases.size == baseAttributes.size,
+ validate(aliases.size == baseAttributes.size,
"Number of aliases must match the number of fields in an inline table.", ctx)
baseAttributes.zip(aliases).map(p => p._1.withName(p._2))
} else {
@@ -1089,7 +1089,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
// We currently only allow foldable integers.
def value: Int = {
val e = expression(ctx.expression)
- assert(e.resolved && e.foldable && e.dataType == IntegerType,
+ validate(e.resolved && e.foldable && e.dataType == IntegerType,
"Frame bound value must be a constant integer.",
ctx)
e.eval().asInstanceOf[Int]
@@ -1342,7 +1342,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
*/
override def visitInterval(ctx: IntervalContext): Literal = withOrigin(ctx) {
val intervals = ctx.intervalField.asScala.map(visitIntervalField)
- assert(intervals.nonEmpty, "at least one time unit should be given for interval literal", ctx)
+ validate(intervals.nonEmpty, "at least one time unit should be given for interval literal", ctx)
Literal(intervals.reduce(_.add(_)))
}
@@ -1369,7 +1369,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
case (from, Some(t)) =>
throw new ParseException(s"Intervals FROM $from TO $t are not supported.", ctx)
}
- assert(interval != null, "No interval can be constructed", ctx)
+ validate(interval != null, "No interval can be constructed", ctx)
interval
} catch {
// Handle Exceptions thrown by CalendarInterval
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
index b04ce58e233aa..bc35ae2f55409 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
@@ -77,8 +77,8 @@ object ParserUtils {
Origin(Option(token.getLine), Option(token.getCharPositionInLine))
}
- /** Assert if a condition holds. If it doesn't throw a parse exception. */
- def assert(f: => Boolean, message: String, ctx: ParserRuleContext): Unit = {
+ /** Validate the condition. If it does not hold, throw a parse exception. */
+ def validate(f: => Boolean, message: String, ctx: ParserRuleContext): Unit = {
if (!f) {
throw new ParseException(message, ctx)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index a13c45fe2ffee..9560563a8ca56 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis
import java.sql.Timestamp
-import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{Division, FunctionArgumentConversion}
+import org.apache.spark.sql.catalyst.analysis.TypeCoercion._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
@@ -730,6 +730,13 @@ class TypeCoercionSuite extends PlanTest {
// the right expression to Decimal.
ruleTest(rules, sum(Divide(Decimal(4.0), 3)), sum(Divide(Decimal(4.0), 3)))
}
+
+ test("SPARK-17117 null type coercion in divide") {
+ val rules = Seq(FunctionArgumentConversion, Division, ImplicitTypeCasts)
+ val nullLit = Literal.create(null, NullType)
+ ruleTest(rules, Divide(1L, nullLit), Divide(Cast(1L, DoubleType), Cast(nullLit, DoubleType)))
+ ruleTest(rules, Divide(nullLit, 1L), Divide(Cast(nullLit, DoubleType), Cast(1L, DoubleType)))
+ }
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java
index 0d624d17f4cd7..b903aeeb5c9ea 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java
@@ -31,6 +31,8 @@
import java.util.Map;
import java.util.Set;
+import scala.Option;
+
import static org.apache.parquet.filter2.compat.RowGroupFilter.filterRowGroups;
import static org.apache.parquet.format.converter.ParquetMetadataConverter.NO_FILTER;
import static org.apache.parquet.format.converter.ParquetMetadataConverter.range;
@@ -59,8 +61,12 @@
import org.apache.parquet.hadoop.util.ConfigurationUtil;
import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.Types;
+import org.apache.spark.TaskContext;
+import org.apache.spark.TaskContext$;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.StructType$;
+import org.apache.spark.util.AccumulatorV2;
+import org.apache.spark.util.LongAccumulator;
/**
* Base class for custom RecordReaders for Parquet that directly materialize to `T`.
@@ -144,6 +150,18 @@ public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptCont
for (BlockMetaData block : blocks) {
this.totalRowCount += block.getRowCount();
}
+
+ // For test purposes.
+ // If the predefined accumulator exists, the number of row groups to read will be added to
+ // the accumulator, so the test case can check whether row groups were filtered or not.
+ TaskContext taskContext = TaskContext$.MODULE$.get();
+ if (taskContext != null) {
+ Option<AccumulatorV2<?, ?>> accu = (Option<AccumulatorV2<?, ?>>) taskContext.taskMetrics()
+ .lookForAccumulatorByName("numRowGroups");
+ if (accu.isDefined()) {
+ ((LongAccumulator)accu.get()).add((long)blocks.size());
+ }
+ }
}
/**
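
A rough Scala sketch of how a suite could exploit this test hook is shown below. It is illustrative only: the temp path is arbitrary, and the accumulator has to be referenced inside the scan's task closure so that the reader's lookForAccumulatorByName can find it by name on the executor side.

{{{
import org.apache.spark.sql.SparkSession
import org.apache.spark.util.LongAccumulator

object NumRowGroupsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("numRowGroups-sketch").getOrCreate()
    import spark.implicits._

    val path = "/tmp/numRowGroups-sketch"                 // illustrative location
    (1 to 1024).toDF("a").coalesce(1).write.mode("overwrite").parquet(path)

    // Named accumulator; the reader adds the number of row groups it reads to it.
    val numRowGroups: LongAccumulator = spark.sparkContext.longAccumulator("numRowGroups")

    // The filter cannot match anything, so with row-group filtering the accumulator
    // should stay at 0. Touching it in the closure ships it with the scan tasks.
    val df = spark.read.parquet(path).filter("a < 0")
    df.foreachPartition(_ => numRowGroups.add(0))

    println(s"row groups read: ${numRowGroups.value}")
    spark.stop()
  }
}
}}}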
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index a6867a67eeade..8eec42aab4fa3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -21,10 +21,11 @@ import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.function._
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, OuterScopes}
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.expressions.ReduceAggregator
/**
* :: Experimental ::
@@ -177,10 +178,9 @@ class KeyValueGroupedDataset[K, V] private[sql](
* @since 1.6.0
*/
def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = {
- val func = (key: K, it: Iterator[V]) => Iterator((key, it.reduce(f)))
-
- implicit val resultEncoder = ExpressionEncoder.tuple(kExprEnc, vExprEnc)
- flatMapGroups(func)
+ val vEncoder = encoderFor[V]
+ val aggregator: TypedColumn[V, V] = new ReduceAggregator[V](f)(vEncoder).toColumn
+ agg(aggregator)
}
/**
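
Functionally, reduceGroups behaves as before; only the implementation now goes through a ReduceAggregator-backed typed column instead of flatMapGroups. A small usage sketch (local master and sample data are illustrative):

{{{
import org.apache.spark.sql.SparkSession

object ReduceGroupsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("reduceGroups-sketch").getOrCreate()
    import spark.implicits._

    val words = Seq("spark", "sql", "scala", "stream").toDS()

    // Group words by their first letter and keep the longest word in each group.
    val longestPerLetter = words
      .groupByKey(_.substring(0, 1))
      .reduceGroups((a, b) => if (a.length >= b.length) a else b)

    longestPerLetter.collect().foreach { case (letter, word) => println(s"$letter -> $word") }
    spark.stop()
  }
}
}}}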
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 946d8cbc6bf4a..c88206c81a04e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -822,16 +822,19 @@ object SparkSession {
// No active nor global default session. Create a new one.
val sparkContext = userSuppliedContext.getOrElse {
// set app name if not given
- if (!options.contains("spark.app.name")) {
- options += "spark.app.name" -> java.util.UUID.randomUUID().toString
- }
-
+ val randomAppName = java.util.UUID.randomUUID().toString
val sparkConf = new SparkConf()
options.foreach { case (k, v) => sparkConf.set(k, v) }
+ if (!sparkConf.contains("spark.app.name")) {
+ sparkConf.setAppName(randomAppName)
+ }
val sc = SparkContext.getOrCreate(sparkConf)
// maybe this is an existing SparkContext, update its SparkConf which may be used
// by SparkSession
options.foreach { case (k, v) => sc.conf.set(k, v) }
+ if (!sc.conf.contains("spark.app.name")) {
+ sc.conf.setAppName(randomAppName)
+ }
sc
}
session = new SparkSession(sparkContext)
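
The net effect of this hunk is that a missing spark.app.name is now filled in on the SparkConf (and on a reused SparkContext's conf) instead of being written back into the caller's options map. A minimal sketch of the user-visible behaviour:

{{{
import org.apache.spark.sql.SparkSession

object AppNameSketch {
  def main(args: Array[String]): Unit = {
    // No appName() / spark.app.name is supplied, so the builder falls back to
    // a random UUID set on the SparkConf before the SparkContext is created.
    val spark = SparkSession.builder()
      .master("local[2]")
      .getOrCreate()

    println(s"generated app name: ${spark.sparkContext.appName}")
    spark.stop()
  }
}
}}}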
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index de2503a87ab7d..83b7c779ab818 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -31,7 +31,7 @@ import org.apache.spark.storage.StorageLevel
import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK
/** Holds a cached logical plan and its data */
-private[sql] case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryRelation)
+case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryRelation)
/**
* Provides support in a SQLContext for caching query results and automatically using these cached
@@ -41,7 +41,7 @@ private[sql] case class CachedData(plan: LogicalPlan, cachedRepresentation: InMe
*
* Internal to Spark SQL.
*/
-private[sql] class CacheManager extends Logging {
+class CacheManager extends Logging {
@transient
private val cachedData = new scala.collection.mutable.ArrayBuffer[CachedData]
@@ -68,13 +68,13 @@ private[sql] class CacheManager extends Logging {
}
/** Clears all cached tables. */
- private[sql] def clearCache(): Unit = writeLock {
+ def clearCache(): Unit = writeLock {
cachedData.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist())
cachedData.clear()
}
/** Checks if the cache is empty. */
- private[sql] def isEmpty: Boolean = readLock {
+ def isEmpty: Boolean = readLock {
cachedData.isEmpty
}
@@ -83,7 +83,7 @@ private[sql] class CacheManager extends Logging {
* Unlike `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because
* recomputing the in-memory columnar representation of the underlying table is expensive.
*/
- private[sql] def cacheQuery(
+ def cacheQuery(
query: Dataset[_],
tableName: Option[String] = None,
storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock {
@@ -108,7 +108,7 @@ private[sql] class CacheManager extends Logging {
* Tries to remove the data for the given [[Dataset]] from the cache.
* No operation, if it's already uncached.
*/
- private[sql] def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Boolean = writeLock {
+ def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Boolean = writeLock {
val planToCache = query.queryExecution.analyzed
val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan))
val found = dataIndex >= 0
@@ -120,17 +120,17 @@ private[sql] class CacheManager extends Logging {
}
/** Optionally returns cached data for the given [[Dataset]] */
- private[sql] def lookupCachedData(query: Dataset[_]): Option[CachedData] = readLock {
+ def lookupCachedData(query: Dataset[_]): Option[CachedData] = readLock {
lookupCachedData(query.queryExecution.analyzed)
}
/** Optionally returns cached data for the given [[LogicalPlan]]. */
- private[sql] def lookupCachedData(plan: LogicalPlan): Option[CachedData] = readLock {
+ def lookupCachedData(plan: LogicalPlan): Option[CachedData] = readLock {
cachedData.find(cd => plan.sameResult(cd.plan))
}
/** Replaces segments of the given logical plan with cached versions where possible. */
- private[sql] def useCachedData(plan: LogicalPlan): LogicalPlan = {
+ def useCachedData(plan: LogicalPlan): LogicalPlan = {
plan transformDown {
case currentFragment =>
lookupCachedData(currentFragment)
@@ -143,7 +143,7 @@ private[sql] class CacheManager extends Logging {
* Invalidates the cache of any data that contains `plan`. Note that it is possible that this
* function will over invalidate.
*/
- private[sql] def invalidateCache(plan: LogicalPlan): Unit = writeLock {
+ def invalidateCache(plan: LogicalPlan): Unit = writeLock {
cachedData.foreach {
case data if data.plan.collect { case p if p.sameResult(plan) => p }.nonEmpty =>
data.cachedRepresentation.recache()
@@ -155,7 +155,7 @@ private[sql] class CacheManager extends Logging {
* Invalidates the cache of any data that contains `resourcePath` in one or more
* `HadoopFsRelation` node(s) as part of its logical plan.
*/
- private[sql] def invalidateCachedPath(
+ def invalidateCachedPath(
sparkSession: SparkSession, resourcePath: String): Unit = writeLock {
val (fs, qualifiedPath) = {
val path = new Path(resourcePath)
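
These visibility changes expose CacheManager to other Spark modules; end-user code keeps reaching it indirectly, for instance through the catalog API, whose cache calls are ultimately served by cacheQuery/uncacheQuery above. A small sketch of that indirect path (sample data is illustrative):

{{{
import org.apache.spark.sql.SparkSession

object CacheSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("cache-sketch").getOrCreate()
    import spark.implicits._

    val df = Seq((1, "a"), (2, "b")).toDF("id", "value")
    df.createOrReplaceTempView("pairs")

    // Both calls are served by the CacheManager behind the session.
    spark.catalog.cacheTable("pairs")
    println(s"cached? ${spark.catalog.isCached("pairs")}")

    spark.catalog.uncacheTable("pairs")
    println(s"cached? ${spark.catalog.isCached("pairs")}")
    spark.stop()
  }
}
}}}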
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 09203e69983da..ba30bed0b450e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -75,7 +75,7 @@ object RDDConversions {
}
/** Logical plan node for scanning data from an RDD. */
-private[sql] case class LogicalRDD(
+case class LogicalRDD(
output: Seq[Attribute],
rdd: RDD[InternalRow])(session: SparkSession)
extends LogicalPlan with MultiInstanceRelation {
@@ -106,12 +106,12 @@ private[sql] case class LogicalRDD(
}
/** Physical plan node for scanning data from an RDD. */
-private[sql] case class RDDScanExec(
+case class RDDScanExec(
output: Seq[Attribute],
rdd: RDD[InternalRow],
override val nodeName: String) extends LeafExecNode {
- private[sql] override lazy val metrics = Map(
+ override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
protected override def doExecute(): RDD[InternalRow] = {
@@ -130,7 +130,7 @@ private[sql] case class RDDScanExec(
}
}
-private[sql] trait DataSourceScanExec extends LeafExecNode {
+trait DataSourceScanExec extends LeafExecNode {
val rdd: RDD[InternalRow]
val relation: BaseRelation
val metastoreTableIdentifier: Option[TableIdentifier]
@@ -147,7 +147,7 @@ private[sql] trait DataSourceScanExec extends LeafExecNode {
}
/** Physical plan node for scanning data from a relation. */
-private[sql] case class RowDataSourceScanExec(
+case class RowDataSourceScanExec(
output: Seq[Attribute],
rdd: RDD[InternalRow],
@transient relation: BaseRelation,
@@ -156,7 +156,7 @@ private[sql] case class RowDataSourceScanExec(
override val metastoreTableIdentifier: Option[TableIdentifier])
extends DataSourceScanExec with CodegenSupport {
- private[sql] override lazy val metrics =
+ override lazy val metrics =
Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
val outputUnsafeRows = relation match {
@@ -222,7 +222,7 @@ private[sql] case class RowDataSourceScanExec(
}
/** Physical plan node for scanning data from a batched relation. */
-private[sql] case class BatchedDataSourceScanExec(
+case class BatchedDataSourceScanExec(
output: Seq[Attribute],
rdd: RDD[InternalRow],
@transient relation: BaseRelation,
@@ -231,7 +231,7 @@ private[sql] case class BatchedDataSourceScanExec(
override val metastoreTableIdentifier: Option[TableIdentifier])
extends DataSourceScanExec with CodegenSupport {
- private[sql] override lazy val metrics =
+ override lazy val metrics =
Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
"scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time"))
@@ -337,7 +337,7 @@ private[sql] case class BatchedDataSourceScanExec(
}
}
-private[sql] object DataSourceScanExec {
+object DataSourceScanExec {
// Metadata keys
val INPUT_PATHS = "InputPaths"
val PUSHED_FILTERS = "PushedFilters"
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
index 4c046f7bdca48..d5603b3b00914 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
@@ -39,7 +39,7 @@ case class ExpandExec(
child: SparkPlan)
extends UnaryExecNode with CodegenSupport {
- private[sql] override lazy val metrics = Map(
+ override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
// The GroupExpressions can output data with arbitrary partitioning, so set it
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/FileRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/FileRelation.scala
index 7a2a9eed5807d..a299fed7fd14a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/FileRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/FileRelation.scala
@@ -22,7 +22,7 @@ package org.apache.spark.sql.execution
* the list of paths that it returns will be returned to a user who calls `inputPaths` on any
* DataFrame that queries this relation.
*/
-private[sql] trait FileRelation {
+trait FileRelation {
/** Returns the list of files that will be read when scanning this relation. */
def inputFiles: Array[String]
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
index 8b62c5507c0c8..39189a2b0c72c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
@@ -55,7 +55,7 @@ case class GenerateExec(
child: SparkPlan)
extends UnaryExecNode {
- private[sql] override lazy val metrics = Map(
+ override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
override def producedAttributes: AttributeSet = AttributeSet(output)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala
index df2f238d8c2e0..9f53a99346caa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala
@@ -26,11 +26,11 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
/**
* Physical plan node for scanning data from a local collection.
*/
-private[sql] case class LocalTableScanExec(
+case class LocalTableScanExec(
output: Seq[Attribute],
rows: Seq[InternalRow]) extends LeafExecNode {
- private[sql] override lazy val metrics = Map(
+ override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
private val unsafeRows: Array[InternalRow] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala
index 7462dbc4eba3a..717ff93eab5d4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.InternalRow
* iterator to consume the next row, whereas RowIterator combines these calls into a single
* [[advanceNext()]] method.
*/
-private[sql] abstract class RowIterator {
+abstract class RowIterator {
/**
* Advance this iterator by a single row. Returns `false` if this iterator has no more rows
* and `true` otherwise. If this returns `true`, then the new row can be retrieved by calling
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
index 6cb1a44a2044a..ec07aab359ac6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd,
SparkListenerSQLExecutionStart}
-private[sql] object SQLExecution {
+object SQLExecution {
val EXECUTION_ID_KEY = "spark.sql.execution.id"
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
index 66a16ac576b3a..cde3ed48ffeaf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
@@ -22,11 +22,9 @@ import org.apache.spark.executor.TaskMetrics
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, OrderedDistribution, UnspecifiedDistribution}
import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.types._
-import org.apache.spark.util.collection.unsafe.sort.RadixSort;
/**
* Performs (external) sorting.
@@ -52,7 +50,7 @@ case class SortExec(
private val enableRadixSort = sqlContext.conf.enableRadixSort
- override private[sql] lazy val metrics = Map(
+ override lazy val metrics = Map(
"sortTime" -> SQLMetrics.createTimingMetric(sparkContext, "sort time"),
"peakMemory" -> SQLMetrics.createSizeMetric(sparkContext, "peak memory"),
"spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 045ccc7bd6eae..79cb40948b982 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -72,24 +72,24 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
/**
* Return all metadata that describes more details of this SparkPlan.
*/
- private[sql] def metadata: Map[String, String] = Map.empty
+ def metadata: Map[String, String] = Map.empty
/**
* Return all metrics containing metrics of this SparkPlan.
*/
- private[sql] def metrics: Map[String, SQLMetric] = Map.empty
+ def metrics: Map[String, SQLMetric] = Map.empty
/**
* Reset all the metrics.
*/
- private[sql] def resetMetrics(): Unit = {
+ def resetMetrics(): Unit = {
metrics.valuesIterator.foreach(_.reset())
}
/**
* Return a LongSQLMetric according to the name.
*/
- private[sql] def longMetric(name: String): SQLMetric = metrics(name)
+ def longMetric(name: String): SQLMetric = metrics(name)
// TODO: Move to `DistributedPlan`
/** Specifies how data is partitioned across different nodes in the cluster. */
@@ -395,7 +395,7 @@ object SparkPlan {
ThreadUtils.newDaemonCachedThreadPool("subquery", 16))
}
-private[sql] trait LeafExecNode extends SparkPlan {
+trait LeafExecNode extends SparkPlan {
override def children: Seq[SparkPlan] = Nil
override def producedAttributes: AttributeSet = outputSet
}
@@ -407,7 +407,7 @@ object UnaryExecNode {
}
}
-private[sql] trait UnaryExecNode extends SparkPlan {
+trait UnaryExecNode extends SparkPlan {
def child: SparkPlan
override def children: Seq[SparkPlan] = child :: Nil
@@ -415,7 +415,7 @@ private[sql] trait UnaryExecNode extends SparkPlan {
override def outputPartitioning: Partitioning = child.outputPartitioning
}
-private[sql] trait BinaryExecNode extends SparkPlan {
+trait BinaryExecNode extends SparkPlan {
def left: SparkPlan
def right: SparkPlan
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
index f84070a0c4bcb..7aa93126fdabd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
@@ -47,7 +47,7 @@ class SparkPlanInfo(
}
}
-private[sql] object SparkPlanInfo {
+private[execution] object SparkPlanInfo {
def fromSparkPlan(plan: SparkPlan): SparkPlanInfo = {
val children = plan match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
index f3cd9236e086a..876b3341d217e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
@@ -18,7 +18,6 @@
package org.apache.spark.sql.execution
import scala.collection.JavaConverters._
-import scala.util.Try
import org.antlr.v4.runtime.{ParserRuleContext, Token}
import org.antlr.v4.runtime.tree.TerminalNode
@@ -405,6 +404,20 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder {
Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec))
}
+ /**
+ * Create a [[AlterTableRecoverPartitionsCommand]] command.
+ *
+ * For example:
+ * {{{
+ * MSCK REPAIR TABLE tablename
+ * }}}
+ */
+ override def visitRepairTable(ctx: RepairTableContext): LogicalPlan = withOrigin(ctx) {
+ AlterTableRecoverPartitionsCommand(
+ visitTableIdentifier(ctx.tableIdentifier),
+ "MSCK REPAIR TABLE")
+ }
+
/**
* Convert a table property list into a key-value map.
* This should be called through [[visitPropertyKeyValues]] or [[visitPropertyKeys]].
@@ -763,6 +776,19 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder {
ctx.EXISTS != null)
}
+ /**
+ * Create an [[AlterTableRecoverPartitionsCommand]] command
+ *
+ * For example:
+ * {{{
+ * ALTER TABLE table RECOVER PARTITIONS;
+ * }}}
+ */
+ override def visitRecoverPartitions(
+ ctx: RecoverPartitionsContext): LogicalPlan = withOrigin(ctx) {
+ AlterTableRecoverPartitionsCommand(visitTableIdentifier(ctx.tableIdentifier))
+ }
+
/**
* Create an [[AlterTableSetLocationCommand]] command
*
@@ -1152,7 +1178,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder {
entry("mapkey.delim", ctx.keysTerminatedBy) ++
Option(ctx.linesSeparatedBy).toSeq.map { token =>
val value = string(token)
- assert(
+ validate(
value == "\n",
s"LINES TERMINATED BY only supports newline '\\n' right now: $value",
ctx)
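
The two new visitor methods above make both spellings of the partition-recovery command parseable. The sketch below issues them through spark.sql; it assumes a session built with Hive support (spark-hive on the classpath) and uses placeholder table and location names, and per the command's own checks it only works on external, partitioned tables with an explicit location.

{{{
import org.apache.spark.sql.SparkSession

object RecoverPartitionsSqlSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("recover-partitions-sketch")
      .enableHiveSupport()          // the Hive-style DDL below needs Hive support
      .getOrCreate()

    spark.sql(
      """CREATE EXTERNAL TABLE logs (value STRING)
        |PARTITIONED BY (year INT, month INT)
        |LOCATION '/tmp/logs'""".stripMargin)

    // Both statements are parsed into AlterTableRecoverPartitionsCommand.
    spark.sql("ALTER TABLE logs RECOVER PARTITIONS")
    spark.sql("MSCK REPAIR TABLE logs")

    spark.sql("SHOW PARTITIONS logs").show()
    spark.stop()
  }
}
}}}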
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index b619d4edc30de..6d7c193fd42c8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.execution
-import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{AnalysisException, Strategy}
import org.apache.spark.sql.catalyst.InternalRow
@@ -43,13 +42,12 @@ import org.apache.spark.sql.streaming.StreamingQuery
* writing libraries should instead consider using the stable APIs provided in
* [[org.apache.spark.sql.sources]]
*/
-@DeveloperApi
abstract class SparkStrategy extends GenericStrategy[SparkPlan] {
override protected def planLater(plan: LogicalPlan): SparkPlan = PlanLater(plan)
}
-private[sql] case class PlanLater(plan: LogicalPlan) extends LeafExecNode {
+case class PlanLater(plan: LogicalPlan) extends LeafExecNode {
override def output: Seq[Attribute] = plan.output
@@ -58,7 +56,7 @@ private[sql] case class PlanLater(plan: LogicalPlan) extends LeafExecNode {
}
}
-private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
+abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
self: SparkPlanner =>
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
index 484923428f4ad..8ab553369de6d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
@@ -40,12 +40,12 @@ import org.apache.spark.unsafe.Platform
*
* @param numFields the number of fields in the row being serialized.
*/
-private[sql] class UnsafeRowSerializer(
+class UnsafeRowSerializer(
numFields: Int,
dataSize: SQLMetric = null) extends Serializer with Serializable {
override def newInstance(): SerializerInstance =
new UnsafeRowSerializerInstance(numFields, dataSize)
- override private[spark] def supportsRelocationOfSerializedObjects: Boolean = true
+ override def supportsRelocationOfSerializedObjects: Boolean = true
}
private class UnsafeRowSerializerInstance(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index ac4c3aae5f8ee..fb57ed7692de4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -295,7 +295,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co
override def outputPartitioning: Partitioning = child.outputPartitioning
override def outputOrdering: Seq[SortOrder] = child.outputOrdering
- override private[sql] lazy val metrics = Map(
+ override lazy val metrics = Map(
"pipelineTime" -> SQLMetrics.createTimingMetric(sparkContext,
WholeStageCodegenExec.PIPELINE_DURATION_METRIC))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 54d7340d8acd0..cfc47aba889aa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -54,7 +54,7 @@ case class HashAggregateExec(
child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++
aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
- override private[sql] lazy val metrics = Map(
+ override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
"peakMemory" -> SQLMetrics.createSizeMetric(sparkContext, "peak memory"),
"spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
index 05dbacf07a178..7c41e5e4c28aa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
@@ -49,7 +49,7 @@ case class SortAggregateExec(
AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
AttributeSet(aggregateBufferAttributes)
- override private[sql] lazy val metrics = Map(
+ override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index b047bc0641dd2..586e1456ac69e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -204,7 +204,7 @@ sealed trait BufferSetterGetterUtils {
/**
* A Mutable [[Row]] representing a mutable aggregation buffer.
*/
-private[sql] class MutableAggregationBufferImpl (
+private[aggregate] class MutableAggregationBufferImpl(
schema: StructType,
toCatalystConverters: Array[Any => Any],
toScalaConverters: Array[Any => Any],
@@ -266,7 +266,7 @@ private[sql] class MutableAggregationBufferImpl (
/**
* A [[Row]] representing an immutable aggregation buffer.
*/
-private[sql] class InputAggregationBuffer private[sql] (
+private[aggregate] class InputAggregationBuffer(
schema: StructType,
toCatalystConverters: Array[Any => Any],
toScalaConverters: Array[Any => Any],
@@ -319,7 +319,7 @@ private[sql] class InputAggregationBuffer private[sql] (
* The internal wrapper used to hook a [[UserDefinedAggregateFunction]] `udaf` in the
* internal aggregation code path.
*/
-private[sql] case class ScalaUDAF(
+case class ScalaUDAF(
children: Seq[Expression],
udaf: UserDefinedAggregateFunction,
mutableAggBufferOffset: Int = 0,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 185c79f899e68..e6f7081f2916d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -102,7 +102,7 @@ case class FilterExec(condition: Expression, child: SparkPlan)
}
}
- private[sql] override lazy val metrics = Map(
+ override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
override def inputRDDs(): Seq[RDD[InternalRow]] = {
@@ -228,7 +228,7 @@ case class SampleExec(
child: SparkPlan) extends UnaryExecNode with CodegenSupport {
override def output: Seq[Attribute] = child.output
- private[sql] override lazy val metrics = Map(
+ override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
protected override def doExecute(): RDD[InternalRow] = {
@@ -317,7 +317,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
override val output: Seq[Attribute] = range.output
- private[sql] override lazy val metrics = Map(
+ override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
// output attributes should not affect the results
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index 079e122a5a85a..479934a7afc75 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -34,7 +34,7 @@ import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.CollectionAccumulator
-private[sql] object InMemoryRelation {
+object InMemoryRelation {
def apply(
useCompression: Boolean,
batchSize: Int,
@@ -55,15 +55,15 @@ private[sql] object InMemoryRelation {
private[columnar]
case class CachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow)
-private[sql] case class InMemoryRelation(
+case class InMemoryRelation(
output: Seq[Attribute],
useCompression: Boolean,
batchSize: Int,
storageLevel: StorageLevel,
@transient child: SparkPlan,
tableName: Option[String])(
- @transient private[sql] var _cachedColumnBuffers: RDD[CachedBatch] = null,
- private[sql] val batchStats: CollectionAccumulator[InternalRow] =
+ @transient var _cachedColumnBuffers: RDD[CachedBatch] = null,
+ val batchStats: CollectionAccumulator[InternalRow] =
child.sqlContext.sparkContext.collectionAccumulator[InternalRow])
extends logical.LeafNode with MultiInstanceRelation {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index 183e4947b6d72..e63b313cb1d5d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.UserDefinedType
-private[sql] case class InMemoryTableScanExec(
+case class InMemoryTableScanExec(
attributes: Seq[Attribute],
predicates: Seq[Expression],
@transient relation: InMemoryRelation)
@@ -36,7 +36,7 @@ private[sql] case class InMemoryTableScanExec(
override protected def innerChildren: Seq[QueryPlan[_]] = Seq(relation) ++ super.innerChildren
- private[sql] override lazy val metrics = Map(
+ override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
override def output: Seq[Attribute] = attributes
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
index 7eaad81a81615..cce1489abd301 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.types._
* A logical command that is executed for its side-effects. `RunnableCommand`s are
* wrapped in `ExecutedCommand` during execution.
*/
-private[sql] trait RunnableCommand extends LogicalPlan with logical.Command {
+trait RunnableCommand extends LogicalPlan with logical.Command {
override def output: Seq[Attribute] = Seq.empty
override def children: Seq[LogicalPlan] = Seq.empty
def run(sparkSession: SparkSession): Seq[Row]
@@ -45,7 +45,7 @@ private[sql] trait RunnableCommand extends LogicalPlan with logical.Command {
* A physical operator that executes the run method of a `RunnableCommand` and
* saves the result to prevent multiple executions.
*/
-private[sql] case class ExecutedCommandExec(cmd: RunnableCommand) extends SparkPlan {
+case class ExecutedCommandExec(cmd: RunnableCommand) extends SparkPlan {
/**
* A concrete command should override this lazy field to wrap up any side effects caused by the
* command or any other computation that should be evaluated exactly once. The value of this field
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index 226f61ef404ae..aac70e90b883c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -17,19 +17,24 @@
package org.apache.spark.sql.execution.command
+import scala.collection.GenSeq
+import scala.collection.parallel.ForkJoinTaskSupport
+import scala.concurrent.forkjoin.ForkJoinPool
import scala.util.control.NonFatal
+import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter}
+import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
+
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable}
-import org.apache.spark.sql.catalyst.catalog.{CatalogTablePartition, CatalogTableType, SessionCatalog}
+import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable, CatalogTablePartition, CatalogTableType, SessionCatalog}
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils._
import org.apache.spark.sql.execution.datasources.BucketSpec
+import org.apache.spark.sql.execution.datasources.PartitioningUtils
import org.apache.spark.sql.types._
-
// Note: The definition of these commands are based on the ones described in
// https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL
@@ -424,6 +429,111 @@ case class AlterTableDropPartitionCommand(
}
+/**
+ * Recover Partitions in ALTER TABLE: recover all the partitions in the directory of a table and
+ * update the catalog.
+ *
+ * The syntax of this command is:
+ * {{{
+ * ALTER TABLE table RECOVER PARTITIONS;
+ * MSCK REPAIR TABLE table;
+ * }}}
+ */
+case class AlterTableRecoverPartitionsCommand(
+ tableName: TableIdentifier,
+ cmd: String = "ALTER TABLE RECOVER PARTITIONS") extends RunnableCommand {
+ override def run(spark: SparkSession): Seq[Row] = {
+ val catalog = spark.sessionState.catalog
+ if (!catalog.tableExists(tableName)) {
+ throw new AnalysisException(s"Table $tableName in $cmd does not exist.")
+ }
+ val table = catalog.getTableMetadata(tableName)
+ if (catalog.isTemporaryTable(tableName)) {
+ throw new AnalysisException(
+ s"Operation not allowed: $cmd on temporary tables: $tableName")
+ }
+ if (DDLUtils.isDatasourceTable(table)) {
+ throw new AnalysisException(
+ s"Operation not allowed: $cmd on datasource tables: $tableName")
+ }
+ if (table.tableType != CatalogTableType.EXTERNAL) {
+ throw new AnalysisException(
+ s"Operation not allowed: $cmd only works on external tables: $tableName")
+ }
+ if (!DDLUtils.isTablePartitioned(table)) {
+ throw new AnalysisException(
+ s"Operation not allowed: $cmd only works on partitioned tables: $tableName")
+ }
+ if (table.storage.locationUri.isEmpty) {
+ throw new AnalysisException(
+ s"Operation not allowed: $cmd only works on table with location provided: $tableName")
+ }
+
+ val root = new Path(table.storage.locationUri.get)
+ val fs = root.getFileSystem(spark.sparkContext.hadoopConfiguration)
+ // Dummy jobconf to get to the pathFilter defined in configuration
+ // It's very expensive to create a JobConf (ClassUtil.findContainingJar() is slow)
+ val jobConf = new JobConf(spark.sparkContext.hadoopConfiguration, this.getClass)
+ val pathFilter = FileInputFormat.getInputPathFilter(jobConf)
+ val partitionSpecsAndLocs = scanPartitions(
+ spark, fs, pathFilter, root, Map(), table.partitionColumnNames.map(_.toLowerCase))
+ val parts = partitionSpecsAndLocs.map { case (spec, location) =>
+ // inherit table storage format (possibly except for location)
+ CatalogTablePartition(spec, table.storage.copy(locationUri = Some(location.toUri.toString)))
+ }
+ spark.sessionState.catalog.createPartitions(tableName,
+ parts.toArray[CatalogTablePartition], ignoreIfExists = true)
+ Seq.empty[Row]
+ }
+
+ @transient private lazy val evalTaskSupport = new ForkJoinTaskSupport(new ForkJoinPool(8))
+
+ private def scanPartitions(
+ spark: SparkSession,
+ fs: FileSystem,
+ filter: PathFilter,
+ path: Path,
+ spec: TablePartitionSpec,
+ partitionNames: Seq[String]): GenSeq[(TablePartitionSpec, Path)] = {
+ if (partitionNames.length == 0) {
+ return Seq(spec -> path)
+ }
+
+ val statuses = fs.listStatus(path)
+ val threshold = spark.conf.get("spark.rdd.parallelListingThreshold", "10").toInt
+ val statusPar: GenSeq[FileStatus] =
+ if (partitionNames.length > 1 && statuses.length > threshold || partitionNames.length > 2) {
+ val parArray = statuses.par
+ parArray.tasksupport = evalTaskSupport
+ parArray
+ } else {
+ statuses
+ }
+ statusPar.flatMap { st =>
+ val name = st.getPath.getName
+ if (st.isDirectory && name.contains("=")) {
+ val ps = name.split("=", 2)
+ val columnName = PartitioningUtils.unescapePathName(ps(0)).toLowerCase
+ // TODO: Validate the value
+ val value = PartitioningUtils.unescapePathName(ps(1))
+ // compare case-insensitively, but preserve the original case
+ if (columnName == partitionNames(0)) {
+ scanPartitions(
+ spark, fs, filter, st.getPath, spec ++ Map(columnName -> value), partitionNames.drop(1))
+ } else {
+ logWarning(s"expect partition column ${partitionNames(0)}, but got ${ps(0)}, ignore it")
+ Seq()
+ }
+ } else {
+ if (name != "_SUCCESS" && name != "_temporary" && !name.startsWith(".")) {
+ logWarning(s"ignore ${new Path(path, name)}")
+ }
+ Seq()
+ }
+ }
+ }
+}
+
/**
* A command that sets the location of a table or a partition.
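
For intuition, scanPartitions above walks nested column=value directories under the table location, accumulating a partition spec per leaf. The standalone sketch below mirrors that recursion on the local file system; it is illustrative only (the real command walks a Hadoop FileSystem, applies the configured PathFilter, and parallelizes wide directory levels), and the temp layout is made up.

{{{
import java.nio.file.{Files, Path}
import scala.collection.JavaConverters._

object ScanPartitionsSketch {
  // Recursively collect (partition spec, leaf directory) pairs from a
  // column=value directory tree, one level per expected partition column.
  def scan(dir: Path, spec: Map[String, String], cols: Seq[String]): Seq[(Map[String, String], Path)] = {
    if (cols.isEmpty) return Seq(spec -> dir)
    Files.list(dir).iterator().asScala.toSeq.flatMap { child =>
      val name = child.getFileName.toString
      if (Files.isDirectory(child) && name.contains("=")) {
        val Array(col, value) = name.split("=", 2)
        if (col.toLowerCase == cols.head) scan(child, spec + (col.toLowerCase -> value), cols.tail)
        else Seq.empty              // unexpected column name: warn-and-skip in the real command
      } else Seq.empty              // plain files and non-partition directories are ignored
    }
  }

  def main(args: Array[String]): Unit = {
    val root = Files.createTempDirectory("recover-sketch")
    Files.createDirectories(root.resolve("year=2016/month=08"))
    Files.createDirectories(root.resolve("year=2016/month=09"))
    scan(root, Map.empty, Seq("year", "month")).foreach(println)
  }
}
}}}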
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index f572b93991e0c..f5727da387d13 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.execution.datasources
-import java.util.ServiceLoader
+import java.util.{ServiceConfigurationError, ServiceLoader}
import scala.collection.JavaConverters._
import scala.language.{existentials, implicitConversions}
@@ -123,50 +123,63 @@ case class DataSource(
val loader = Utils.getContextOrSparkClassLoader
val serviceLoader = ServiceLoader.load(classOf[DataSourceRegister], loader)
- serviceLoader.asScala.filter(_.shortName().equalsIgnoreCase(provider)).toList match {
- // the provider format did not match any given registered aliases
- case Nil =>
- try {
- Try(loader.loadClass(provider)).orElse(Try(loader.loadClass(provider2))) match {
- case Success(dataSource) =>
- // Found the data source using fully qualified path
- dataSource
- case Failure(error) =>
- if (provider.toLowerCase == "orc" ||
+ try {
+ serviceLoader.asScala.filter(_.shortName().equalsIgnoreCase(provider)).toList match {
+ // the provider format did not match any given registered aliases
+ case Nil =>
+ try {
+ Try(loader.loadClass(provider)).orElse(Try(loader.loadClass(provider2))) match {
+ case Success(dataSource) =>
+ // Found the data source using fully qualified path
+ dataSource
+ case Failure(error) =>
+ if (provider.toLowerCase == "orc" ||
provider.startsWith("org.apache.spark.sql.hive.orc")) {
- throw new AnalysisException(
- "The ORC data source must be used with Hive support enabled")
- } else if (provider.toLowerCase == "avro" ||
+ throw new AnalysisException(
+ "The ORC data source must be used with Hive support enabled")
+ } else if (provider.toLowerCase == "avro" ||
provider == "com.databricks.spark.avro") {
- throw new AnalysisException(
- s"Failed to find data source: ${provider.toLowerCase}. Please use Spark " +
- "package http://spark-packages.org/package/databricks/spark-avro")
+ throw new AnalysisException(
+ s"Failed to find data source: ${provider.toLowerCase}. Please use Spark " +
+ "package http://spark-packages.org/package/databricks/spark-avro")
+ } else {
+ throw new ClassNotFoundException(
+ s"Failed to find data source: $provider. Please find packages at " +
+ "http://spark-packages.org",
+ error)
+ }
+ }
+ } catch {
+ case e: NoClassDefFoundError => // This one won't be caught by Scala NonFatal
+ // NoClassDefFoundError's class name uses "/" rather than "." for packages
+ val className = e.getMessage.replaceAll("/", ".")
+ if (spark2RemovedClasses.contains(className)) {
+ throw new ClassNotFoundException(s"$className was removed in Spark 2.0. " +
+ "Please check if your library is compatible with Spark 2.0", e)
} else {
- throw new ClassNotFoundException(
- s"Failed to find data source: $provider. Please find packages at " +
- "http://spark-packages.org",
- error)
+ throw e
}
}
- } catch {
- case e: NoClassDefFoundError => // This one won't be caught by Scala NonFatal
- // NoClassDefFoundError's class name uses "/" rather than "." for packages
- val className = e.getMessage.replaceAll("/", ".")
- if (spark2RemovedClasses.contains(className)) {
- throw new ClassNotFoundException(s"$className was removed in Spark 2.0. " +
- "Please check if your library is compatible with Spark 2.0", e)
- } else {
- throw e
- }
+ case head :: Nil =>
+ // there is exactly one registered alias
+ head.getClass
+ case sources =>
+ // There are multiple registered aliases for the input
+ sys.error(s"Multiple sources found for $provider " +
+ s"(${sources.map(_.getClass.getName).mkString(", ")}), " +
+ "please specify the fully qualified class name.")
+ }
+ } catch {
+ case e: ServiceConfigurationError if e.getCause.isInstanceOf[NoClassDefFoundError] =>
+ // NoClassDefFoundError's class name uses "/" rather than "." for packages
+ val className = e.getCause.getMessage.replaceAll("/", ".")
+ if (spark2RemovedClasses.contains(className)) {
+ throw new ClassNotFoundException(s"Detected an incompatible DataSourceRegister. " +
+ "Please remove the incompatible library from classpath or upgrade it. " +
+ s"Error: ${e.getMessage}", e)
+ } else {
+ throw e
}
- case head :: Nil =>
- // there is exactly one registered alias
- head.getClass
- case sources =>
- // There are multiple registered aliases for the input
- sys.error(s"Multiple sources found for $provider " +
- s"(${sources.map(_.getClass.getName).mkString(", ")}), " +
- "please specify the fully qualified class name.")
}
}
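
This restructuring wraps the whole ServiceLoader scan so that a ServiceConfigurationError caused by an incompatible, pre-2.0 DataSourceRegister surfaces as a clear ClassNotFoundException. From the caller's side, resolution still works by short name or fully qualified class; a small sketch of the happy path and of the "failed to find data source" branch (the temp path and bogus provider name are placeholders):

{{{
import org.apache.spark.sql.SparkSession

object DataSourceLookupSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("lookup-sketch").getOrCreate()
    import spark.implicits._

    val path = "/tmp/lookup-sketch-json"                   // illustrative path
    Seq((1, "a"), (2, "b")).toDF("id", "value").write.mode("overwrite").json(path)

    // Resolved through the DataSourceRegister short name "json".
    println(spark.read.format("json").load(path).count())

    // An unknown provider falls into the ClassNotFoundException branch above.
    try {
      spark.read.format("no.such.datasource").load(path)
    } catch {
      case e: ClassNotFoundException => println(e.getMessage)
    }
    spark.stop()
  }
}
}}}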
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 27133f0a43f2e..bd65d0251197b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -43,7 +43,7 @@ import org.apache.spark.unsafe.types.UTF8String
* Replaces generic operations with specific variants that are designed to work with Spark
* SQL Data Sources.
*/
-private[sql] case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] {
+case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] {
def resolver: Resolver = {
if (conf.caseSensitiveAnalysis) {
@@ -54,7 +54,7 @@ private[sql] case class DataSourceAnalysis(conf: CatalystConf) extends Rule[Logi
}
// The access modifier is used to expose this method to tests.
- private[sql] def convertStaticPartitions(
+ def convertStaticPartitions(
sourceAttributes: Seq[Attribute],
providedPartitions: Map[String, Option[String]],
targetAttributes: Seq[Attribute],
@@ -202,7 +202,7 @@ private[sql] case class DataSourceAnalysis(conf: CatalystConf) extends Rule[Logi
* Replaces [[SimpleCatalogRelation]] with data source table if its table property contains data
* source information.
*/
-private[sql] class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] {
+class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] {
private def readDataSourceTable(sparkSession: SparkSession, table: CatalogTable): LogicalPlan = {
val userSpecifiedSchema = DDLUtils.getSchemaFromTableProperties(table)
@@ -242,7 +242,7 @@ private[sql] class FindDataSourceTable(sparkSession: SparkSession) extends Rule[
/**
* A Strategy for planning scans over data sources defined using the sources API.
*/
-private[sql] object DataSourceStrategy extends Strategy with Logging {
+object DataSourceStrategy extends Strategy with Logging {
def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match {
case PhysicalOperation(projects, filters, l @ LogicalRelation(t: CatalystScan, _, _)) =>
pruneFilterProjectRaw(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
index 8af9562330e81..74510f9c08b6f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
@@ -54,7 +54,7 @@ import org.apache.spark.sql.execution.SparkPlan
* is under the threshold with the addition of the next file, add it. If not, open a new bucket
* and add it. Proceed to the next file.
*/
-private[sql] object FileSourceStrategy extends Strategy with Logging {
+object FileSourceStrategy extends Strategy with Logging {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case PhysicalOperation(projects, filters,
l @ LogicalRelation(files: HadoopFsRelation, _, table)) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala
index 8549ae96e2f39..b2ff68a833fea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.sources.InsertableRelation
/**
* Inserts the results of `query` in to a relation that extends [[InsertableRelation]].
*/
-private[sql] case class InsertIntoDataSourceCommand(
+case class InsertIntoDataSourceCommand(
logicalRelation: LogicalRelation,
query: LogicalPlan,
overwrite: Boolean)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
index 1426dcf4697ff..d8b8fae3bf2d3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
@@ -55,7 +55,7 @@ import org.apache.spark.sql.internal.SQLConf
* 4. If all tasks are committed, commit the job, otherwise aborts the job; If any exception is
* thrown during job commitment, also aborts the job.
*/
-private[sql] case class InsertIntoHadoopFsRelationCommand(
+case class InsertIntoHadoopFsRelationCommand(
outputPath: Path,
partitionColumns: Seq[Attribute],
bucketSpec: Option[BucketSpec],
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileCatalog.scala
index 811e96c99a96d..cef9d4d9c7f1b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileCatalog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileCatalog.scala
@@ -204,6 +204,6 @@ abstract class PartitioningAwareFileCatalog(
private def isDataPath(path: Path): Boolean = {
val name = path.getName
- !(name.startsWith("_") || name.startsWith("."))
+ !((name.startsWith("_") && !name.contains("=")) || name.startsWith("."))
}
}
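
The only behavioural change here is that a leading underscore no longer disqualifies a directory whose name looks like a column=value partition. A tiny standalone sketch of the revised predicate over plain file names:

{{{
// Illustrative only: the revised isDataPath check, lifted out as a plain
// function over file names so the behaviour change is easy to see.
object IsDataPathSketch {
  def isDataPath(name: String): Boolean =
    !((name.startsWith("_") && !name.contains("=")) || name.startsWith("."))

  def main(args: Array[String]): Unit = {
    Seq("part-00000", "_SUCCESS", "_temporary", "_col=1", ".hidden").foreach { name =>
      println(f"$name%-12s -> ${isDataPath(name)}")
    }
  }
}
}}}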
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
index c3561099d6842..504464216e5a4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
import org.apache.spark.sql.types._
+// TODO: We should tighten up visibility of the classes here once we clean up Hive coupling.
object PartitionDirectory {
def apply(values: InternalRow, path: String): PartitionDirectory =
@@ -41,22 +42,23 @@ object PartitionDirectory {
* Holds a directory in a partitioned collection of files as well as the partition values
* in the form of a Row. Before scanning, the files at `path` need to be enumerated.
*/
-private[sql] case class PartitionDirectory(values: InternalRow, path: Path)
+case class PartitionDirectory(values: InternalRow, path: Path)
-private[sql] case class PartitionSpec(
+case class PartitionSpec(
partitionColumns: StructType,
partitions: Seq[PartitionDirectory])
-private[sql] object PartitionSpec {
+object PartitionSpec {
val emptySpec = PartitionSpec(StructType(Seq.empty[StructField]), Seq.empty[PartitionDirectory])
}
-private[sql] object PartitioningUtils {
+object PartitioningUtils {
// This duplicates default value of Hive `ConfVars.DEFAULTPARTITIONNAME`, since sql/core doesn't
// depend on Hive.
- private[sql] val DEFAULT_PARTITION_NAME = "__HIVE_DEFAULT_PARTITION__"
+ val DEFAULT_PARTITION_NAME = "__HIVE_DEFAULT_PARTITION__"
- private[sql] case class PartitionValues(columnNames: Seq[String], literals: Seq[Literal]) {
+ private[datasources] case class PartitionValues(columnNames: Seq[String], literals: Seq[Literal])
+ {
require(columnNames.size == literals.size)
}
@@ -83,7 +85,7 @@ private[sql] object PartitioningUtils {
* path = "hdfs://:/path/to/partition/a=2/b=world/c=6.28")))
* }}}
*/
- private[sql] def parsePartitions(
+ private[datasources] def parsePartitions(
paths: Seq[Path],
defaultPartitionName: String,
typeInference: Boolean,
@@ -166,7 +168,7 @@ private[sql] object PartitioningUtils {
* hdfs://<host>:<port>/path/to/partition
* }}}
*/
- private[sql] def parsePartition(
+ private[datasources] def parsePartition(
path: Path,
defaultPartitionName: String,
typeInference: Boolean,
@@ -249,7 +251,7 @@ private[sql] object PartitioningUtils {
* DoubleType -> StringType
* }}}
*/
- private[sql] def resolvePartitions(
+ def resolvePartitions(
pathsWithPartitionValues: Seq[(Path, PartitionValues)]): Seq[PartitionValues] = {
if (pathsWithPartitionValues.isEmpty) {
Seq.empty
@@ -275,7 +277,7 @@ private[sql] object PartitioningUtils {
}
}
- private[sql] def listConflictingPartitionColumns(
+ private[datasources] def listConflictingPartitionColumns(
pathWithPartitionValues: Seq[(Path, PartitionValues)]): String = {
val distinctPartColNames = pathWithPartitionValues.map(_._2.columnNames).distinct
@@ -308,7 +310,7 @@ private[sql] object PartitioningUtils {
* [[IntegerType]], [[LongType]], [[DoubleType]], [[DecimalType.SYSTEM_DEFAULT]], and
* [[StringType]].
*/
- private[sql] def inferPartitionColumnValue(
+ private[datasources] def inferPartitionColumnValue(
raw: String,
defaultPartitionName: String,
typeInference: Boolean): Literal = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
index 9a0b46c1a4a5e..e25924b1ba1ef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
@@ -40,14 +40,14 @@ import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
/** A container for all the details required when writing to a table. */
-case class WriteRelation(
+private[datasources] case class WriteRelation(
sparkSession: SparkSession,
dataSchema: StructType,
path: String,
prepareJobForWrite: Job => OutputWriterFactory,
bucketSpec: Option[BucketSpec])
-private[sql] abstract class BaseWriterContainer(
+private[datasources] abstract class BaseWriterContainer(
@transient val relation: WriteRelation,
@transient private val job: Job,
isAppend: Boolean)
@@ -234,7 +234,7 @@ private[sql] abstract class BaseWriterContainer(
/**
* A writer that writes all of the rows in a partition to a single file.
*/
-private[sql] class DefaultWriterContainer(
+private[datasources] class DefaultWriterContainer(
relation: WriteRelation,
job: Job,
isAppend: Boolean)
@@ -293,7 +293,7 @@ private[sql] class DefaultWriterContainer(
* done by maintaining a HashMap of open files until `maxFiles` is reached. If this occurs, the
* writer externally sorts the remaining rows and then writes them out one file at a time.
*/
-private[sql] class DynamicPartitionWriterContainer(
+private[datasources] class DynamicPartitionWriterContainer(
relation: WriteRelation,
job: Job,
partitionColumns: Seq[Attribute],
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala
index 6008d73717f77..2bafe967993b9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala
@@ -31,7 +31,7 @@ private[sql] case class BucketSpec(
bucketColumnNames: Seq[String],
sortColumnNames: Seq[String])
-private[sql] object BucketingUtils {
+object BucketingUtils {
// The file name of bucketed data should have 3 parts:
// 1. some other information in the head of file name
// 2. bucket id part, some numbers, starts with "_"
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index 22fb8163b1c0a..10fe541a2c575 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -23,7 +23,7 @@ import java.text.SimpleDateFormat
import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.datasources.{CompressionCodecs, ParseModes}
-private[sql] class CSVOptions(@transient private val parameters: Map[String, String])
+private[csv] class CSVOptions(@transient private val parameters: Map[String, String])
extends Logging with Serializable {
private def getChar(paramName: String, default: Char): Char = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala
index 7929ebbd90f71..0a996547d2536 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala
@@ -30,7 +30,7 @@ import org.apache.spark.internal.Logging
* @param params Parameters object
* @param headers headers for the columns
*/
-private[sql] abstract class CsvReader(params: CSVOptions, headers: Seq[String]) {
+private[csv] abstract class CsvReader(params: CSVOptions, headers: Seq[String]) {
protected lazy val parser: CsvParser = {
val settings = new CsvParserSettings()
@@ -60,7 +60,7 @@ private[sql] abstract class CsvReader(params: CSVOptions, headers: Seq[String])
* @param params Parameters object for configuration
* @param headers headers for columns
*/
-private[sql] class LineCsvWriter(params: CSVOptions, headers: Seq[String]) extends Logging {
+private[csv] class LineCsvWriter(params: CSVOptions, headers: Seq[String]) extends Logging {
private val writerSettings = new CsvWriterSettings
private val format = writerSettings.getFormat
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
index 083ac3350ef02..10d84f4a70d5d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
@@ -159,7 +159,7 @@ object CSVRelation extends Logging {
}
}
-private[sql] class CSVOutputWriterFactory(params: CSVOptions) extends OutputWriterFactory {
+private[csv] class CSVOutputWriterFactory(params: CSVOptions) extends OutputWriterFactory {
override def newInstance(
path: String,
bucketId: Option[Int],
@@ -170,7 +170,7 @@ private[sql] class CSVOutputWriterFactory(params: CSVOptions) extends OutputWrit
}
}
-private[sql] class CsvOutputWriter(
+private[csv] class CsvOutputWriter(
path: String,
dataSchema: StructType,
context: TaskAttemptContext,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala
index 0b5a19fe9384b..4351bed99460a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala
@@ -76,7 +76,7 @@ abstract class OutputWriterFactory extends Serializable {
* through the [[OutputWriterFactory]] implementation.
* @since 2.0.0
*/
- private[sql] def newWriter(path: String): OutputWriter = {
+ def newWriter(path: String): OutputWriter = {
throw new UnsupportedOperationException("newInstance with just path not supported")
}
}
@@ -249,7 +249,7 @@ trait FileFormat {
* appends partition values to [[InternalRow]]s produced by the reader function [[buildReader]]
* returns.
*/
- private[sql] def buildReaderWithPartitionValues(
+ def buildReaderWithPartitionValues(
sparkSession: SparkSession,
dataSchema: StructType,
partitionSchema: StructType,
@@ -356,14 +356,14 @@ trait FileCatalog {
/**
* Helper methods for gathering metadata from HDFS.
*/
-private[sql] object HadoopFsRelation extends Logging {
+object HadoopFsRelation extends Logging {
/** Checks if we should filter out this path name. */
def shouldFilterOut(pathName: String): Boolean = {
// We filter everything that starts with _ and ., except _common_metadata and _metadata
// because Parquet needs to find those metadata files from leaf files returned by this method.
// We should refactor this logic to not mix metadata files with data files.
- (pathName.startsWith("_") || pathName.startsWith(".")) &&
+ ((pathName.startsWith("_") && !pathName.contains("=")) || pathName.startsWith(".")) &&
!pathName.startsWith("_common_metadata") && !pathName.startsWith("_metadata")
}
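
The effect of adding `!pathName.contains("=")` is easiest to see by exercising the predicate in isolation. A small sketch; the predicate body is copied from the hunk above, and the sample file names are invented:

    object ShouldFilterOutSketch {
      // Same predicate as HadoopFsRelation.shouldFilterOut after this change.
      def shouldFilterOut(pathName: String): Boolean =
        ((pathName.startsWith("_") && !pathName.contains("=")) || pathName.startsWith(".")) &&
          !pathName.startsWith("_common_metadata") && !pathName.startsWith("_metadata")

      def main(args: Array[String]): Unit = {
        println(shouldFilterOut("_SUCCESS"))   // true: staging/marker files are still skipped
        println(shouldFilterOut(".hidden"))    // true: hidden files are still skipped
        println(shouldFilterOut("_col1=0"))    // false: partition dirs for "_"-prefixed columns are now kept
        println(shouldFilterOut("_metadata"))  // false: Parquet summary files are still kept
      }
    }
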
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 24e2c1a5fd2f6..f655155287974 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -38,11 +38,11 @@ import org.apache.spark.unsafe.types.UTF8String
/**
* Data corresponding to one partition of a JDBCRDD.
*/
-private[sql] case class JDBCPartition(whereClause: String, idx: Int) extends Partition {
+case class JDBCPartition(whereClause: String, idx: Int) extends Partition {
override def index: Int = idx
}
-private[sql] object JDBCRDD extends Logging {
+object JDBCRDD extends Logging {
/**
* Maps a JDBC type to a Catalyst type. This function is called only when
@@ -192,7 +192,7 @@ private[sql] object JDBCRDD extends Logging {
* Turns a single Filter into a String representing a SQL expression.
* Returns None for an unhandled filter.
*/
- private[jdbc] def compileFilter(f: Filter): Option[String] = {
+ def compileFilter(f: Filter): Option[String] = {
Option(f match {
case EqualTo(attr, value) => s"$attr = ${compileValue(value)}"
case EqualNullSafe(attr, value) =>
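
With `compileFilter` made public, it can be invoked directly on `org.apache.spark.sql.sources` filters. A rough usage sketch; the exact rendering (for example how string values are quoted) comes from `compileValue` and is only approximated in the comments:

    import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD
    import org.apache.spark.sql.sources.EqualTo

    object CompileFilterSketch {
      def main(args: Array[String]): Unit = {
        // Expected to yield something like Some("age = 21").
        println(JDBCRDD.compileFilter(EqualTo("age", 21)))
        // String values go through compileValue, giving roughly Some("name = 'fred'").
        println(JDBCRDD.compileFilter(EqualTo("name", "fred")))
        // Unhandled filters come back as None.
      }
    }
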
@@ -275,7 +275,7 @@ private[sql] object JDBCRDD extends Logging {
* driver code and the workers must be able to access the database; the driver
* needs to fetch the schema while the workers need to fetch the data.
*/
-private[sql] class JDBCRDD(
+private[jdbc] class JDBCRDD(
sc: SparkContext,
getConnection: () => Connection,
schema: StructType,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
index 86aef1f7d4411..c58de3ae6f9e6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
@@ -55,7 +55,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
.getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord)
val jsonFiles = files.filterNot { status =>
val name = status.getPath.getName
- name.startsWith("_") || name.startsWith(".")
+ (name.startsWith("_") && !name.contains("=")) || name.startsWith(".")
}.toArray
val jsonSchema = InferSchema.infer(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
index 5397d50b6c7a0..94980886c6265 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
@@ -46,12 +46,13 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjectio
import org.apache.spark.sql.catalyst.parser.LegacyTypeStringParser
import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils
import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.util.SerializableConfiguration
-private[sql] class ParquetFileFormat
+class ParquetFileFormat
extends FileFormat
with DataSourceRegister
with Logging
@@ -233,7 +234,8 @@ private[sql] class ParquetFileFormat
// Lists `FileStatus`es of all leaf nodes (files) under all base directories.
val leaves = allFiles.filter { f =>
isSummaryFile(f.getPath) ||
- !(f.getPath.getName.startsWith("_") || f.getPath.getName.startsWith("."))
+ !((f.getPath.getName.startsWith("_") && !f.getPath.getName.contains("=")) ||
+ f.getPath.getName.startsWith("."))
}.toArray.sortBy(_.getPath.toString)
FileTypes(
@@ -266,7 +268,7 @@ private[sql] class ParquetFileFormat
true
}
- override private[sql] def buildReaderWithPartitionValues(
+ override def buildReaderWithPartitionValues(
sparkSession: SparkSession,
dataSchema: StructType,
partitionSchema: StructType,
@@ -355,6 +357,11 @@ private[sql] class ParquetFileFormat
val hadoopAttemptContext =
new TaskAttemptContextImpl(broadcastedHadoopConf.value.value, attemptId)
+ // Try to push down filters when filter push-down is enabled.
+ // Notice: This push-down is RowGroups level, not individual records.
+ if (pushed.isDefined) {
+ ParquetInputFormat.setFilterPredicate(hadoopAttemptContext.getConfiguration, pushed.get)
+ }
val parquetReader = if (enableVectorizedReader) {
val vectorizedReader = new VectorizedParquetRecordReader()
vectorizedReader.initialize(split, hadoopAttemptContext)
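
The two new lines above wire an already-translated predicate into the per-task Hadoop configuration. A sketch of the same wiring in isolation, assuming the `ParquetFilters.createFilter(schema, filter)` helper of this branch; the schema and filter below are invented:

    import org.apache.hadoop.conf.Configuration
    import org.apache.parquet.hadoop.ParquetInputFormat
    import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters
    import org.apache.spark.sql.sources.EqualTo
    import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

    object ParquetPushDownSketch {
      def main(args: Array[String]): Unit = {
        val schema = StructType(Seq(StructField("age", IntegerType)))
        // Translate a data source filter into a Parquet FilterPredicate, if supported.
        val pushed = ParquetFilters.createFilter(schema, EqualTo("age", 21))
        val conf = new Configuration()
        // Row-group level pruning only: matching groups may still contain non-matching records.
        pushed.foreach(p => ParquetInputFormat.setFilterPredicate(conf, p))
      }
    }
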
@@ -422,7 +429,7 @@ private[sql] class ParquetFileFormat
* writes the data to the path used to generate the output writer. Callers of this factory
have to ensure which files are to be considered as committed.
*/
-private[sql] class ParquetOutputWriterFactory(
+private[parquet] class ParquetOutputWriterFactory(
sqlConf: SQLConf,
dataSchema: StructType,
hadoopConf: Configuration,
@@ -471,7 +478,7 @@ private[sql] class ParquetOutputWriterFactory(
* Returns an [[OutputWriter]] that writes data to the given path without using
* [[OutputCommitter]].
*/
- override private[sql] def newWriter(path: String): OutputWriter = new OutputWriter {
+ override def newWriter(path: String): OutputWriter = new OutputWriter {
// Create TaskAttemptContext that is used to pass on Configuration to the ParquetRecordWriter
private val hadoopTaskAttempId = new TaskAttemptID(new TaskID(new JobID, TaskType.MAP, 0), 0)
@@ -518,7 +525,7 @@ private[sql] class ParquetOutputWriterFactory(
// NOTE: This class is instantiated and used on executor side only, no need to be serializable.
-private[sql] class ParquetOutputWriter(
+private[parquet] class ParquetOutputWriter(
path: String,
bucketId: Option[Int],
context: TaskAttemptContext)
@@ -556,12 +563,13 @@ private[sql] class ParquetOutputWriter(
override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal")
- override protected[sql] def writeInternal(row: InternalRow): Unit = recordWriter.write(null, row)
+ override def writeInternal(row: InternalRow): Unit = recordWriter.write(null, row)
override def close(): Unit = recordWriter.close(context)
}
-private[sql] object ParquetFileFormat extends Logging {
+
+object ParquetFileFormat extends Logging {
/**
* If parquet's block size (row group size) setting is larger than the min split size,
* we use parquet's block size setting as the min split size. Otherwise, we will create
@@ -708,7 +716,7 @@ private[sql] object ParquetFileFormat extends Logging {
* distinguish binary and string). This method generates a correct schema by merging Metastore
* schema data types and Parquet schema field names.
*/
- private[sql] def mergeMetastoreParquetSchema(
+ def mergeMetastoreParquetSchema(
metastoreSchema: StructType,
parquetSchema: StructType): StructType = {
def schemaConflictMessage: String =
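
A hedged usage sketch of the now-public helper, using two hand-built schemas; the field names are invented, and the expectation in the comment follows the doc above (Metastore data types combined with Parquet field names):

    import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
    import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

    object MergeSchemaSketch {
      def main(args: Array[String]): Unit = {
        // Metastore schema: lower-cased field names, authoritative data types.
        val metastoreSchema = StructType(Seq(
          StructField("id", IntegerType),
          StructField("name", StringType)))
        // Parquet schema: the original, case-sensitive field names.
        val parquetSchema = StructType(Seq(
          StructField("ID", IntegerType),
          StructField("Name", StringType)))
        // Expected: Metastore data types merged with Parquet's field-name casing.
        println(ParquetFileFormat.mergeMetastoreParquetSchema(metastoreSchema, parquetSchema))
      }
    }
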
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
index 70ae829219d59..2edd2757428aa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
@@ -28,7 +28,7 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
import org.apache.spark.sql.sources
import org.apache.spark.sql.types._
-private[sql] object ParquetFilters {
+object ParquetFilters {
case class SetInFilter[T <: Comparable[T]](
valueSet: Set[T]) extends UserDefinedPredicate[T] with Serializable {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala
index dd2e915e7b7f9..3eec582714e15 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.internal.SQLConf
/**
* Options for the Parquet data source.
*/
-private[sql] class ParquetOptions(
+private[parquet] class ParquetOptions(
@transient private val parameters: Map[String, String],
@transient private val sqlConf: SQLConf)
extends Serializable {
@@ -56,8 +56,8 @@ private[sql] class ParquetOptions(
}
-private[sql] object ParquetOptions {
- private[sql] val MERGE_SCHEMA = "mergeSchema"
+object ParquetOptions {
+ val MERGE_SCHEMA = "mergeSchema"
// The parquet compression short names
private val shortParquetCompressionCodecNames = Map(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index 15b9d14bd73fe..05908d908fd20 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.sources.{BaseRelation, InsertableRelation}
/**
* Tries to replace [[UnresolvedRelation]]s with [[ResolveDataSource]].
*/
-private[sql] class ResolveDataSource(sparkSession: SparkSession) extends Rule[LogicalPlan] {
+class ResolveDataSource(sparkSession: SparkSession) extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case u: UnresolvedRelation if u.tableIdentifier.database.isDefined =>
try {
@@ -67,7 +67,7 @@ private[sql] class ResolveDataSource(sparkSession: SparkSession) extends Rule[Lo
* table. It also does data type casting and field renaming, to make sure that the columns to be
* inserted have the correct data type and fields have the correct names.
*/
-private[sql] case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] {
+case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] {
private def preprocess(
insert: InsertIntoTable,
tblName: String,
@@ -147,7 +147,7 @@ private[sql] case class PreprocessTableInsertion(conf: SQLConf) extends Rule[Log
/**
* A rule to do various checks before inserting into or writing to a data source table.
*/
-private[sql] case class PreWriteCheck(conf: SQLConf, catalog: SessionCatalog)
+case class PreWriteCheck(conf: SQLConf, catalog: SessionCatalog)
extends (LogicalPlan => Unit) {
def failAnalysis(msg: String): Unit = { throw new AnalysisException(msg) }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index e89f792496d6a..082f97a8808fa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -104,7 +104,7 @@ package object debug {
}
}
- private[sql] case class DebugExec(child: SparkPlan) extends UnaryExecNode with CodegenSupport {
+ case class DebugExec(child: SparkPlan) extends UnaryExecNode with CodegenSupport {
def output: Seq[Attribute] = child.output
class SetAccumulator[T] extends AccumulatorV2[T, HashSet[T]] {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
index bd0841db7e8ab..a809076de5419 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
@@ -38,7 +38,7 @@ case class BroadcastExchangeExec(
mode: BroadcastMode,
child: SparkPlan) extends Exchange {
- override private[sql] lazy val metrics = Map(
+ override lazy val metrics = Map(
"dataSize" -> SQLMetrics.createMetric(sparkContext, "data size (bytes)"),
"collectTime" -> SQLMetrics.createMetric(sparkContext, "time to collect (ms)"),
"buildTime" -> SQLMetrics.createMetric(sparkContext, "time to build (ms)"),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala
index 2ea6ee38a932a..57da85fa84f99 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala
@@ -79,7 +79,7 @@ import org.apache.spark.sql.execution.{ShuffledRowRDD, SparkPlan}
* - post-shuffle partition 1: pre-shuffle partition 2
* - post-shuffle partition 2: pre-shuffle partition 3 and 4
*/
-private[sql] class ExchangeCoordinator(
+class ExchangeCoordinator(
numExchanges: Int,
advisoryTargetPostShuffleInputSize: Long,
minNumPostShufflePartitions: Option[Int] = None)
@@ -112,7 +112,7 @@ private[sql] class ExchangeCoordinator(
* Estimates partition start indices for post-shuffle partitions based on
* mapOutputStatistics provided by all pre-shuffle stages.
*/
- private[sql] def estimatePartitionStartIndices(
+ def estimatePartitionStartIndices(
mapOutputStatistics: Array[MapOutputStatistics]): Array[Int] = {
// If we have mapOutputStatistics.length < numExchange, it is because we do not submit
// a stage when the number of partitions of this dependency is 0.
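
The mapping in the class doc (one post-shuffle partition per run of pre-shuffle partitions that fits under the advisory target size) is essentially greedy packing. A self-contained sketch of that idea, independent of the actual `estimatePartitionStartIndices` implementation and using made-up sizes:

    object CoalesceSketch {
      // Greedily pack pre-shuffle partitions into post-shuffle partitions whose total
      // input size stays at or below the target, returning the start index of each group.
      def startIndices(partitionSizes: Array[Long], targetSize: Long): Array[Int] = {
        val indices = scala.collection.mutable.ArrayBuffer(0)
        var current = 0L
        for (i <- partitionSizes.indices) {
          if (current > 0L && current + partitionSizes(i) > targetSize) {
            indices += i
            current = 0L
          }
          current += partitionSizes(i)
        }
        indices.toArray
      }

      def main(args: Array[String]): Unit = {
        // Sizes 100, 20, 60, 40, 90 with target 128 give start indices 0, 2, 4,
        // i.e. groups [100, 20], [60, 40] and [90].
        println(startIndices(Array(100L, 20L, 60L, 40L, 90L), 128L).mkString(", "))
      }
    }
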
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
index afe0fbea73bd9..7a4a251370706 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
@@ -40,7 +40,7 @@ case class ShuffleExchange(
child: SparkPlan,
@transient coordinator: Option[ExchangeCoordinator]) extends Exchange {
- override private[sql] lazy val metrics = Map(
+ override lazy val metrics = Map(
"dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"))
override def nodeName: String = {
@@ -81,7 +81,8 @@ case class ShuffleExchange(
* the partitioning scheme defined in `newPartitioning`. Those partitions of
* the returned ShuffleDependency will be the input of shuffle.
*/
- private[sql] def prepareShuffleDependency(): ShuffleDependency[Int, InternalRow, InternalRow] = {
+ private[exchange] def prepareShuffleDependency()
+ : ShuffleDependency[Int, InternalRow, InternalRow] = {
ShuffleExchange.prepareShuffleDependency(
child.execute(), child.output, newPartitioning, serializer)
}
@@ -92,7 +93,7 @@ case class ShuffleExchange(
* partition start indices array. If this optional array is defined, the returned
* [[ShuffledRowRDD]] will fetch pre-shuffle partitions based on indices of this array.
*/
- private[sql] def preparePostShuffleRDD(
+ private[exchange] def preparePostShuffleRDD(
shuffleDependency: ShuffleDependency[Int, InternalRow, InternalRow],
specifiedPartitionStartIndices: Option[Array[Int]] = None): ShuffledRowRDD = {
// If an array of partition start indices is provided, we need to use this array
@@ -194,7 +195,7 @@ object ShuffleExchange {
* the partitioning scheme defined in `newPartitioning`. Those partitions of
* the returned ShuffleDependency will be the input of shuffle.
*/
- private[sql] def prepareShuffleDependency(
+ def prepareShuffleDependency(
rdd: RDD[InternalRow],
outputAttributes: Seq[Attribute],
newPartitioning: Partitioning,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index 7c194ab72643a..0f24baacd18d6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -45,7 +45,7 @@ case class BroadcastHashJoinExec(
right: SparkPlan)
extends BinaryExecNode with HashJoin with CodegenSupport {
- override private[sql] lazy val metrics = Map(
+ override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
override def requiredChildDistribution: Seq[Distribution] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
index 4d43765f8fcd3..6a9965f1a24cd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
@@ -37,7 +37,7 @@ case class BroadcastNestedLoopJoinExec(
condition: Option[Expression],
withinBroadcastThreshold: Boolean = true) extends BinaryExecNode {
- override private[sql] lazy val metrics = Map(
+ override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
/** BuildRight means the right relation <=> the broadcast relation. */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
index 3a0b6efdfc910..c97fffe88b719 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
@@ -34,7 +34,6 @@ import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
* will be much faster than building the right partition for every row in the left RDD; it also
* materializes the right RDD (in case the right RDD is nondeterministic).
*/
-private[spark]
class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numFieldsOfRight: Int)
extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) {
@@ -78,7 +77,7 @@ class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numField
for (x <- rdd1.iterator(partition.s1, context);
y <- createIter()) yield (x, y)
CompletionIterator[(UnsafeRow, UnsafeRow), Iterator[(UnsafeRow, UnsafeRow)]](
- resultIter, sorter.cleanupResources)
+ resultIter, sorter.cleanupResources())
}
}
@@ -89,7 +88,7 @@ case class CartesianProductExec(
condition: Option[Expression]) extends BinaryExecNode {
override def output: Seq[Attribute] = left.output ++ right.output
- override private[sql] lazy val metrics = Map(
+ override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
protected override def doPrepare(): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
index 0036f9aadc5d9..afb6e5e3dd235 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
@@ -39,7 +39,7 @@ case class ShuffledHashJoinExec(
right: SparkPlan)
extends BinaryExecNode with HashJoin {
- override private[sql] lazy val metrics = Map(
+ override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
"buildDataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size of build side"),
"buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map"))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index fac6b8de8ed5e..5c9c1e6062f0d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -40,7 +40,7 @@ case class SortMergeJoinExec(
left: SparkPlan,
right: SparkPlan) extends BinaryExecNode with CodegenSupport {
- override private[sql] lazy val metrics = Map(
+ override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
override def output: Seq[Attribute] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
index 9817a56f499a5..15afa0b1a5391 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
@@ -55,17 +55,17 @@ class SQLMetric(val metricType: String, initValue: Long = 0L) extends Accumulato
override def value: Long = _value
// Provide special identifier as metadata so we can tell that this is a `SQLMetric` later
- private[spark] override def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = {
+ override def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = {
new AccumulableInfo(
id, name, update, value, true, true, Some(AccumulatorContext.SQL_ACCUM_IDENTIFIER))
}
}
-private[sql] object SQLMetrics {
- private[sql] val SUM_METRIC = "sum"
- private[sql] val SIZE_METRIC = "size"
- private[sql] val TIMING_METRIC = "timing"
+object SQLMetrics {
+ private val SUM_METRIC = "sum"
+ private val SIZE_METRIC = "size"
+ private val TIMING_METRIC = "timing"
def createMetric(sc: SparkContext, name: String): SQLMetric = {
val acc = new SQLMetric(SUM_METRIC)
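
With `SQLMetrics` public, an operator-style metric can be created and driven outside the SQL package. A minimal sketch using a local `SparkContext`; the app name and master setting are arbitrary:

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.execution.metric.SQLMetrics

    object SqlMetricSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setMaster("local[1]").setAppName("metric-sketch"))
        try {
          // createMetric registers a SUM-typed SQLMetric (an AccumulatorV2) with the context.
          val rows = SQLMetrics.createMetric(sc, "number of output rows")
          rows.add(3L)
          println(rows.value)  // 3
        } finally {
          sc.stop()
        }
      }
    }
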
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index 829bcae6f95d4..16e44845d5283 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.SparkPlan
* Extracts all the Python UDFs in a logical aggregate that depend on an aggregate expression or
* grouping key, and evaluates them after the aggregate.
*/
-private[spark] object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
+object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
/**
* Returns whether the expression could only be evaluated within aggregate.
@@ -90,7 +90,7 @@ private[spark] object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
* This has the limitation that the input to the Python UDF is not allowed to include attributes from
* multiple child operators.
*/
-private[spark] object ExtractPythonUDFs extends Rule[SparkPlan] {
+object ExtractPythonUDFs extends Rule[SparkPlan] {
private def hasPythonUDF(e: Expression): Boolean = {
e.find(_.isInstanceOf[PythonUDF]).isDefined
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala
index 70539da348b0e..d2178e971ec20 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala
@@ -21,12 +21,12 @@ import org.apache.spark.api.r._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.api.r.SQLUtils._
import org.apache.spark.sql.Row
-import org.apache.spark.sql.types.{BinaryType, StructField, StructType}
+import org.apache.spark.sql.types.StructType
/**
* A function wrapper that applies the given R function to each partition.
*/
-private[sql] case class MapPartitionsRWrapper(
+case class MapPartitionsRWrapper(
func: Array[Byte],
packageNames: Array[Byte],
broadcastVars: Array[Broadcast[Object]],
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
index b19344f04383f..b9dbfcf7734c3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types._
-private[sql] object FrequentItems extends Logging {
+object FrequentItems extends Logging {
/** A helper class wrapping `MutableMap[Any, Long]` for simplicity. */
private class FreqItemCounter(size: Int) extends Serializable {
@@ -79,7 +79,7 @@ private[sql] object FrequentItems extends Logging {
* than 1e-4.
* @return A Local DataFrame with the Array of frequent items for each column.
*/
- private[sql] def singlePassFreqItems(
+ def singlePassFreqItems(
df: DataFrame,
cols: Seq[String],
support: Double): DataFrame = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index ea58df70b3252..50eecb409830f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
-private[sql] object StatFunctions extends Logging {
+object StatFunctions extends Logging {
import QuantileSummaries.Stats
@@ -337,7 +337,7 @@ private[sql] object StatFunctions extends Logging {
}
/** Calculate the Pearson Correlation Coefficient for the given columns */
- private[sql] def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = {
+ def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = {
val counts = collectStatisticalData(df, cols, "correlation")
counts.Ck / math.sqrt(counts.MkX * counts.MkY)
}
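
A short usage sketch of the now-public helper; the DataFrame below is invented, and the printed value should match `df.stat.corr("x", "y")` (1.0 for perfectly correlated data, up to floating-point rounding):

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.execution.stat.StatFunctions

    object PearsonSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[1]").appName("stat-sketch").getOrCreate()
        import spark.implicits._
        val df = Seq((1.0, 2.0), (2.0, 4.0), (3.0, 6.0)).toDF("x", "y")
        // Same computation as DataFrame.stat.corr, just called directly.
        println(StatFunctions.pearsonCorrelation(df, Seq("x", "y")))
        spark.stop()
      }
    }
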
@@ -407,13 +407,13 @@ private[sql] object StatFunctions extends Logging {
* @param cols the column names
* @return the covariance of the two columns.
*/
- private[sql] def calculateCov(df: DataFrame, cols: Seq[String]): Double = {
+ def calculateCov(df: DataFrame, cols: Seq[String]): Double = {
val counts = collectStatisticalData(df, cols, "covariance")
counts.cov
}
/** Generate a table of frequencies for the elements of two columns. */
- private[sql] def crossTabulate(df: DataFrame, col1: String, col2: String): DataFrame = {
+ def crossTabulate(df: DataFrame, col1: String, col2: String): DataFrame = {
val tableName = s"${col1}_$col2"
val counts = df.groupBy(col1, col2).agg(count("*")).take(1e6.toInt)
if (counts.length == 1e6.toInt) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index 7367c68d0a0e5..05294df2673dc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.streaming.OutputMode
* A variant of [[QueryExecution]] that allows the execution of the given [[LogicalPlan]]
* plan incrementally, possibly preserving state in between each execution.
*/
-class IncrementalExecution private[sql](
+class IncrementalExecution(
sparkSession: SparkSession,
logicalPlan: LogicalPlan,
val outputMode: OutputMode,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index af2229a46bebb..4d05af0b60358 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -49,10 +49,10 @@ class StreamExecution(
override val id: Long,
override val name: String,
checkpointRoot: String,
- private[sql] val logicalPlan: LogicalPlan,
+ val logicalPlan: LogicalPlan,
val sink: Sink,
val trigger: Trigger,
- private[sql] val triggerClock: Clock,
+ val triggerClock: Clock,
val outputMode: OutputMode)
extends StreamingQuery with Logging {
@@ -74,7 +74,7 @@ class StreamExecution(
* input source.
*/
@volatile
- private[sql] var committedOffsets = new StreamProgress
+ var committedOffsets = new StreamProgress
/**
* Tracks the offsets that are available to be processed, but have not yet been committed to the
@@ -102,10 +102,10 @@ class StreamExecution(
private var state: State = INITIALIZED
@volatile
- private[sql] var lastExecution: QueryExecution = null
+ var lastExecution: QueryExecution = null
@volatile
- private[sql] var streamDeathCause: StreamingQueryException = null
+ var streamDeathCause: StreamingQueryException = null
/* Get the call site in the caller thread; will pass this into the micro batch thread */
private val callSite = Utils.getCallSite()
@@ -115,7 +115,7 @@ class StreamExecution(
* [[org.apache.spark.util.UninterruptibleThread]] to avoid potential deadlocks in using
* [[HDFSMetadataLog]]. See SPARK-14131 for more details.
*/
- private[sql] val microBatchThread =
+ val microBatchThread =
new UninterruptibleThread(s"stream execution thread for $name") {
override def run(): Unit = {
// To fix call site like "run at <unknown>:0", we bridge the call site from the caller
@@ -131,8 +131,7 @@ class StreamExecution(
* processing is done. Thus, the Nth record in this log indicates data that is currently being
* processed and the N-1th entry indicates which offsets have been durably committed to the sink.
*/
- private[sql] val offsetLog =
- new HDFSMetadataLog[CompositeOffset](sparkSession, checkpointFile("offsets"))
+ val offsetLog = new HDFSMetadataLog[CompositeOffset](sparkSession, checkpointFile("offsets"))
/** Whether the query is currently active or not */
override def isActive: Boolean = state == ACTIVE
@@ -159,7 +158,7 @@ class StreamExecution(
* Starts the execution. This returns only after the thread has started and the [[QueryStarted]] event
* has been posted to all the listeners.
*/
- private[sql] def start(): Unit = {
+ def start(): Unit = {
microBatchThread.setDaemon(true)
microBatchThread.start()
startLatch.await() // Wait until thread started and QueryStart event has been posted
@@ -218,10 +217,7 @@ class StreamExecution(
} finally {
state = TERMINATED
sparkSession.streams.notifyQueryTermination(StreamExecution.this)
- postEvent(new QueryTerminated(
- this.toInfo,
- exception.map(_.getMessage),
- exception.map(_.getStackTrace.toSeq).getOrElse(Nil)))
+ postEvent(new QueryTerminated(this.toInfo, exception.map(_.cause).map(Utils.exceptionString)))
terminationLatch.countDown()
}
}
@@ -518,7 +514,7 @@ class StreamExecution(
case object TERMINATED extends State
}
-private[sql] object StreamExecution {
+object StreamExecution {
private val _nextId = new AtomicLong(0)
def nextId: Long = _nextId.getAndIncrement()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala
index 405a5f0387a7e..db0bd9e6bc6f0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala
@@ -26,7 +26,7 @@ class StreamProgress(
val baseMap: immutable.Map[Source, Offset] = new immutable.HashMap[Source, Offset])
extends scala.collection.immutable.Map[Source, Offset] {
- private[sql] def toCompositeOffset(source: Seq[Source]): CompositeOffset = {
+ def toCompositeOffset(source: Seq[Source]): CompositeOffset = {
CompositeOffset(source.map(get))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index 066765324ac94..a67fdceb3cee6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -113,7 +113,7 @@ case class KeyRemoved(key: UnsafeRow) extends StoreUpdate
* the store is the active instance. Accordingly, it either keeps it loaded and performs
* maintenance, or unloads the store.
*/
-private[sql] object StateStore extends Logging {
+object StateStore extends Logging {
val MAINTENANCE_INTERVAL_CONFIG = "spark.sql.streaming.stateStore.maintenanceInterval"
val MAINTENANCE_INTERVAL_DEFAULT_SECS = 60
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala
index e418217238cca..d945d7aff2da4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala
@@ -45,7 +45,7 @@ private object StopCoordinator
extends StateStoreCoordinatorMessage
/** Helper object used to create reference to [[StateStoreCoordinator]]. */
-private[sql] object StateStoreCoordinatorRef extends Logging {
+object StateStoreCoordinatorRef extends Logging {
private val endpointName = "StateStoreCoordinator"
@@ -77,7 +77,7 @@ private[sql] object StateStoreCoordinatorRef extends Logging {
* Reference to a [[StateStoreCoordinator]] that can be used to coordinate instances of
* [[StateStore]]s across all the executors, and get their locations for job scheduling.
*/
-private[sql] class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) {
+class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) {
private[state] def reportActiveInstance(
storeId: StateStoreId,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala
index 4b4fa126b85f3..23fc0bd0bce13 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala
@@ -24,7 +24,7 @@ import scala.xml.Node
import org.apache.spark.internal.Logging
import org.apache.spark.ui.{UIUtils, WebUIPage}
-private[sql] class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging {
+class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging {
private val listener = parent.listener
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
index 6e94791901762..60f13432d78d2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
@@ -46,14 +46,14 @@ case class SparkListenerSQLExecutionEnd(executionId: Long, time: Long)
case class SparkListenerDriverAccumUpdates(executionId: Long, accumUpdates: Seq[(Long, Long)])
extends SparkListenerEvent
-private[sql] class SQLHistoryListenerFactory extends SparkHistoryListenerFactory {
+class SQLHistoryListenerFactory extends SparkHistoryListenerFactory {
override def createListeners(conf: SparkConf, sparkUI: SparkUI): Seq[SparkListener] = {
List(new SQLHistoryListener(conf, sparkUI))
}
}
-private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Logging {
+class SQLListener(conf: SparkConf) extends SparkListener with Logging {
private val retainedExecutions = conf.getInt("spark.sql.ui.retainedExecutions", 1000)
@@ -333,7 +333,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi
/**
* A [[SQLListener]] for rendering the SQL UI in the history server.
*/
-private[spark] class SQLHistoryListener(conf: SparkConf, sparkUI: SparkUI)
+class SQLHistoryListener(conf: SparkConf, sparkUI: SparkUI)
extends SQLListener(conf) {
private var sqlTabAttached = false
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala
index e8675ce749a2b..d0376af3e31ca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.ui
import org.apache.spark.internal.Logging
import org.apache.spark.ui.{SparkUI, SparkUITab}
-private[sql] class SQLTab(val listener: SQLListener, sparkUI: SparkUI)
+class SQLTab(val listener: SQLListener, sparkUI: SparkUI)
extends SparkUITab(sparkUI, "SQL") with Logging {
val parent = sparkUI
@@ -32,6 +32,6 @@ private[sql] class SQLTab(val listener: SQLListener, sparkUI: SparkUI)
parent.addStaticHandler(SQLTab.STATIC_RESOURCE_DIR, "/static/sql")
}
-private[sql] object SQLTab {
+object SQLTab {
private val STATIC_RESOURCE_DIR = "org/apache/spark/sql/execution/ui/static"
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
index 8f5681bfc7cc6..4bb9d6fef4c1d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
@@ -24,7 +24,7 @@ import scala.collection.mutable
import org.apache.commons.lang3.StringEscapeUtils
import org.apache.spark.sql.execution.{SparkPlanInfo, WholeStageCodegenExec}
-import org.apache.spark.sql.execution.metric.SQLMetrics
+
/**
* A graph used for storing information about the execution plan of a DataFrame.
@@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
* Each graph is defined with a set of nodes and a set of edges. Each node represents a node in the
* SparkPlan tree, and each edge represents a parent-child relationship between two nodes.
*/
-private[ui] case class SparkPlanGraph(
+case class SparkPlanGraph(
nodes: Seq[SparkPlanGraphNode], edges: Seq[SparkPlanGraphEdge]) {
def makeDotFile(metrics: Map[Long, String]): String = {
@@ -55,7 +55,7 @@ private[ui] case class SparkPlanGraph(
}
}
-private[sql] object SparkPlanGraph {
+object SparkPlanGraph {
/**
* Build a SparkPlanGraph from the root of a SparkPlan tree.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
new file mode 100644
index 0000000000000..174378304d4a5
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.expressions
+
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+
+/**
+ * An aggregator that uses a single associative and commutative reduce function. This reduce
+ * function can be used to go through all input values and reduce them to a single value.
+ * If there is no input, a null value is returned.
+ *
+ * This class currently assumes there is at least one input row.
+ */
+private[sql] class ReduceAggregator[T: Encoder](func: (T, T) => T)
+ extends Aggregator[T, (Boolean, T), T] {
+
+ private val encoder = implicitly[Encoder[T]]
+
+ override def zero: (Boolean, T) = (false, null.asInstanceOf[T])
+
+ override def bufferEncoder: Encoder[(Boolean, T)] =
+ ExpressionEncoder.tuple(
+ ExpressionEncoder[Boolean](),
+ encoder.asInstanceOf[ExpressionEncoder[T]])
+
+ override def outputEncoder: Encoder[T] = encoder
+
+ override def reduce(b: (Boolean, T), a: T): (Boolean, T) = {
+ if (b._1) {
+ (true, func(b._2, a))
+ } else {
+ (true, a)
+ }
+ }
+
+ override def merge(b1: (Boolean, T), b2: (Boolean, T)): (Boolean, T) = {
+ if (!b1._1) {
+ b2
+ } else if (!b2._1) {
+ b1
+ } else {
+ (true, func(b1._2, b2._2))
+ }
+ }
+
+ override def finish(reduction: (Boolean, T)): T = {
+ if (!reduction._1) {
+ throw new IllegalStateException("ReduceAggregator requires at least one input row")
+ }
+ reduction._2
+ }
+}
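
A minimal sketch of how the new aggregator behaves, driving the `Aggregator` contract by hand. Because the class is `private[sql]`, the sketch lives under the `org.apache.spark.sql` package tree, and `Encoders.scalaInt` is just one convenient encoder choice:

    package org.apache.spark.sql.expressions

    import org.apache.spark.sql.{Encoder, Encoders}

    object ReduceAggregatorSketch {
      def main(args: Array[String]): Unit = {
        implicit val intEncoder: Encoder[Int] = Encoders.scalaInt
        val agg = new ReduceAggregator[Int](_ + _)
        // zero -> reduce -> merge -> finish, exactly as a Dataset aggregation would drive it.
        val b1 = agg.reduce(agg.reduce(agg.zero, 1), 2)  // (true, 3)
        val b2 = agg.reduce(agg.zero, 4)                 // (true, 4)
        println(agg.finish(agg.merge(b1, b2)))           // 7
        // agg.finish(agg.zero) would throw IllegalStateException
        // ("ReduceAggregator requires at least one input row").
      }
    }
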
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index ab09ef7450b04..4e185b85e7660 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2175,7 +2175,8 @@ object functions {
def ltrim(e: Column): Column = withExpr {StringTrimLeft(e.expr) }
/**
- * Extract a specific(idx) group identified by a java regex, from the specified string column.
+ * Extract a specific group matched by a Java regex, from the specified string column.
+ * If the regex did not match, or the specified group did not match, an empty string is returned.
*
* @group string_funcs
* @since 1.5.0
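
A small sketch of the documented behavior; the DataFrame and pattern are invented. Group 1 is extracted when the regex matches, and a non-matching row yields an empty string rather than null:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.regexp_extract

    object RegexpExtractSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[1]").appName("regexp-sketch").getOrCreate()
        import spark.implicits._
        val df = Seq("100-200", "no digits here").toDF("s")
        // First row: "100" (group 1); second row: "" because the pattern does not match.
        df.select(regexp_extract($"s", "(\\d+)-(\\d+)", 1)).show()
        spark.stop()
      }
    }
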
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 1a9bb6a0b54e1..0666a99cfc43e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -23,6 +23,7 @@ import java.util.concurrent.TimeUnit
import scala.collection.JavaConverters._
import scala.collection.immutable
+import org.apache.hadoop.fs.Path
import org.apache.parquet.hadoop.ParquetOutputCommitter
import org.apache.spark.internal.Logging
@@ -55,7 +56,7 @@ object SQLConf {
val WAREHOUSE_PATH = SQLConfigBuilder("spark.sql.warehouse.dir")
.doc("The default location for managed databases and tables.")
.stringConf
- .createWithDefault("file:${system:user.dir}/spark-warehouse")
+ .createWithDefault("${system:user.dir}/spark-warehouse")
val OPTIMIZER_MAX_ITERATIONS = SQLConfigBuilder("spark.sql.optimizer.maxIterations")
.internal()
@@ -691,9 +692,9 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
def variableSubstituteDepth: Int = getConf(VARIABLE_SUBSTITUTE_DEPTH)
def warehousePath: String = {
- getConf(WAREHOUSE_PATH).replace("${system:user.dir}", System.getProperty("user.dir"))
+ new Path(getConf(WAREHOUSE_PATH).replace("${system:user.dir}",
+ System.getProperty("user.dir"))).toString
}
-
override def orderByOrdinal: Boolean = getConf(ORDER_BY_ORDINAL)
override def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL)
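
The substitution-plus-`Path` round trip can be seen in isolation. A small sketch of what the reworked `warehousePath` now does with the new default value (no hard-coded `file:` prefix; the Hadoop `Path` constructor normalizes the resulting string):

    import org.apache.hadoop.fs.Path

    object WarehousePathSketch {
      def main(args: Array[String]): Unit = {
        val default = "${system:user.dir}/spark-warehouse"
        val substituted = default.replace("${system:user.dir}", System.getProperty("user.dir"))
        // Wrapping in a Hadoop Path normalizes the string (separators, redundant slashes).
        println(new Path(substituted).toString)
      }
    }
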
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
index 6c43fe3177d65..54aee5e02bb9c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.internal
-import org.apache.hadoop.conf.Configuration
-
import org.apache.spark.SparkContext
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{SparkSession, SQLContext}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala
index 90f95ca9d4229..bd3e5a5618ec4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala
@@ -22,7 +22,8 @@ import org.apache.spark.sql.execution.streaming.{Offset, StreamExecution}
/**
* :: Experimental ::
- * Exception that stopped a [[StreamingQuery]].
+ * Exception that stopped a [[StreamingQuery]]. Use `cause` to get the actual exception
+ * that caused the failure.
* @param query Query that caused the exception
* @param message Message of this exception
* @param cause Internal cause of this exception
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala
index 3b3cead3a66de..db606abb8ce43 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala
@@ -108,6 +108,5 @@ object StreamingQueryListener {
@Experimental
class QueryTerminated private[sql](
val queryInfo: StreamingQueryInfo,
- val exception: Option[String],
- val stackTrace: Seq[StackTraceElement]) extends Event
+ val exception: Option[String]) extends Event
}
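
For illustration, a sketch of a listener consuming the slimmed-down event. It assumes the callback names used by this branch (`onQueryStarted`, `onQueryProgress`, `onQueryTerminated`) and a `name` field on `StreamingQueryInfo`, so treat it as an approximation rather than the exact API:

    import org.apache.spark.sql.streaming.StreamingQueryListener
    import org.apache.spark.sql.streaming.StreamingQueryListener._

    // After this change a terminated-query event carries only an optional exception string;
    // the stack trace is already folded into it via Utils.exceptionString on the driver.
    class LoggingQueryListener extends StreamingQueryListener {
      override def onQueryStarted(event: QueryStarted): Unit = ()
      override def onQueryProgress(event: QueryProgress): Unit = ()
      override def onQueryTerminated(event: QueryTerminated): Unit = {
        event.exception match {
          case Some(message) => println(s"query ${event.queryInfo.name} failed: $message")
          case None => println(s"query ${event.queryInfo.name} stopped cleanly")
        }
      }
    }
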
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 318b53cdbbaa0..c44fc3d393862 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -327,23 +327,23 @@ private String getResource(String resource) {
@Test
public void testGenericLoad() {
- Dataset df1 = spark.read().format("text").load(getResource("text-suite.txt"));
+ Dataset df1 = spark.read().format("text").load(getResource("test-data/text-suite.txt"));
Assert.assertEquals(4L, df1.count());
Dataset df2 = spark.read().format("text").load(
- getResource("text-suite.txt"),
- getResource("text-suite2.txt"));
+ getResource("test-data/text-suite.txt"),
+ getResource("test-data/text-suite2.txt"));
Assert.assertEquals(5L, df2.count());
}
@Test
public void testTextLoad() {
- Dataset ds1 = spark.read().textFile(getResource("text-suite.txt"));
+ Dataset ds1 = spark.read().textFile(getResource("test-data/text-suite.txt"));
Assert.assertEquals(4L, ds1.count());
Dataset ds2 = spark.read().textFile(
- getResource("text-suite.txt"),
- getResource("text-suite2.txt"));
+ getResource("test-data/text-suite.txt"),
+ getResource("test-data/text-suite2.txt"));
Assert.assertEquals(5L, ds2.count());
}
diff --git a/sql/core/src/test/resources/old-repeated.parquet b/sql/core/src/test/resources/old-repeated.parquet
deleted file mode 100644
index 213f1a90291b3..0000000000000
Binary files a/sql/core/src/test/resources/old-repeated.parquet and /dev/null differ
diff --git a/sql/core/src/test/resources/sql-tests/inputs/arithmetic.sql b/sql/core/src/test/resources/sql-tests/inputs/arithmetic.sql
new file mode 100644
index 0000000000000..f62b10ca0037b
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/arithmetic.sql
@@ -0,0 +1,34 @@
+
+-- unary minus and plus
+select -100;
+select +230;
+select -5.2;
+select +6.8e0;
+select -key, +key from testdata where key = 2;
+select -(key + 1), - key + 1, +(key + 5) from testdata where key = 1;
+select -max(key), +max(key) from testdata;
+select - (-10);
+select + (-key) from testdata where key = 32;
+select - (+max(key)) from testdata;
+select - - 3;
+select - + 20;
+select + + 100;
+select - - max(key) from testdata;
+select + - key from testdata where key = 33;
+
+-- div
+select 5 / 2;
+select 5 / 0;
+select 5 / null;
+select null / 5;
+select 5 div 2;
+select 5 div 0;
+select 5 div null;
+select null div 5;
+
+-- other arithmetics
+select 1 + 2;
+select 1 - 2;
+select 2 * 5;
+select 5 % 3;
+select pmod(-7, 3);
diff --git a/sql/core/src/test/resources/sql-tests/inputs/blacklist.sql b/sql/core/src/test/resources/sql-tests/inputs/blacklist.sql
new file mode 100644
index 0000000000000..d69f8147a5264
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/blacklist.sql
@@ -0,0 +1,4 @@
+-- This is a query file that has been blacklisted.
+-- It includes a query that should crash Spark.
+-- If the test case is run, the whole suite would fail.
+some random not working query that should crash Spark.
diff --git a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql
new file mode 100644
index 0000000000000..3fd1c37e71795
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql
@@ -0,0 +1,4 @@
+-- date time functions
+
+-- [SPARK-16836] current_date and current_timestamp literals
+select current_date = current_date(), current_timestamp = current_timestamp();
diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql
new file mode 100644
index 0000000000000..36b469c61788c
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql
@@ -0,0 +1,50 @@
+-- group by ordinal positions
+
+create temporary view data as select * from values
+ (1, 1),
+ (1, 2),
+ (2, 1),
+ (2, 2),
+ (3, 1),
+ (3, 2)
+ as data(a, b);
+
+-- basic case
+select a, sum(b) from data group by 1;
+
+-- constant case
+select 1, 2, sum(b) from data group by 1, 2;
+
+-- duplicate group by column
+select a, 1, sum(b) from data group by a, 1;
+select a, 1, sum(b) from data group by 1, 2;
+
+-- group by a non-aggregate expression's ordinal
+select a, b + 2, count(2) from data group by a, 2;
+
+-- with alias
+select a as aa, b + 2 as bb, count(2) from data group by 1, 2;
+
+-- foldable non-literal: this should be the same as no grouping.
+select sum(b) from data group by 1 + 0;
+
+-- negative cases: ordinal out of range
+select a, b from data group by -1;
+select a, b from data group by 0;
+select a, b from data group by 3;
+
+-- negative case: position is an aggregate expression
+select a, b, sum(b) from data group by 3;
+select a, b, sum(b) + 2 from data group by 3;
+
+-- negative case: nondeterministic expression
+select a, rand(0), sum(b) from data group by a, 2;
+
+-- negative case: star
+select * from data group by a, b, 1;
+
+-- turn off group by ordinal
+set spark.sql.groupByOrdinal=false;
+
+-- can now group by negative literal
+select sum(b) from data group by -1;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/having.sql b/sql/core/src/test/resources/sql-tests/inputs/having.sql
new file mode 100644
index 0000000000000..364c022d959dc
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/having.sql
@@ -0,0 +1,15 @@
+create temporary view hav as select * from values
+ ("one", 1),
+ ("two", 2),
+ ("three", 3),
+ ("one", 5)
+ as hav(k, v);
+
+-- having clause
+SELECT k, sum(v) FROM hav GROUP BY k HAVING sum(v) > 2;
+
+-- having condition contains grouping column
+SELECT count(k) FROM hav GROUP BY v + 1 HAVING v + 1 = 2;
+
+-- SPARK-11032: resolve having correctly
+SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0);
diff --git a/sql/core/src/test/resources/sql-tests/inputs/limit.sql b/sql/core/src/test/resources/sql-tests/inputs/limit.sql
new file mode 100644
index 0000000000000..892a1bb4b559f
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/limit.sql
@@ -0,0 +1,20 @@
+
+-- limit on various data types
+select * from testdata limit 2;
+select * from arraydata limit 2;
+select * from mapdata limit 2;
+
+-- foldable non-literal in limit
+select * from testdata limit 2 + 1;
+
+select * from testdata limit CAST(1 AS int);
+
+-- limit must be non-negative
+select * from testdata limit -1;
+
+-- limit must be foldable
+select * from testdata limit key > 3;
+
+-- limit must be integer
+select * from testdata limit true;
+select * from testdata limit 'a';
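The negative cases are rejected during analysis, so they surface as `AnalysisException`s with the messages recorded in the golden file rather than as runtime failures. A short sketch of how the same check looks in suite code, mirroring the tests this patch removes from SQLQuerySuite (assumes ScalaTest's `intercept` and a registered `testdata` view):

{{{
// Sketch: an invalid LIMIT fails analysis with a descriptive message.
import org.apache.spark.sql.AnalysisException

val e = intercept[AnalysisException] {
  spark.sql("select * from testdata limit -1")
}
assert(e.getMessage.contains("The limit expression must be equal to or greater than 0"))
}}}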
diff --git a/sql/core/src/test/resources/sql-tests/inputs/literals.sql b/sql/core/src/test/resources/sql-tests/inputs/literals.sql
new file mode 100644
index 0000000000000..62f0d3d0599c6
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/literals.sql
@@ -0,0 +1,92 @@
+-- Literal parsing
+
+-- null
+select null, Null, nUll;
+
+-- boolean
+select true, tRue, false, fALse;
+
+-- byte (tinyint)
+select 1Y;
+select 127Y, -128Y;
+
+-- out of range byte
+select 128Y;
+
+-- short (smallint)
+select 1S;
+select 32767S, -32768S;
+
+-- out of range short
+select 32768S;
+
+-- long (bigint)
+select 1L, 2147483648L;
+select 9223372036854775807L, -9223372036854775808L;
+
+-- out of range long
+select 9223372036854775808L;
+
+-- integral parsing
+
+-- parse int
+select 1, -1;
+
+-- parse int max and min value as int
+select 2147483647, -2147483648;
+
+-- parse long max and min value as long
+select 9223372036854775807, -9223372036854775808;
+
+-- parse as decimals (Long.MaxValue + 1, and Long.MinValue - 1)
+select 9223372036854775808, -9223372036854775809;
+
+-- out of range decimal numbers
+select 1234567890123456789012345678901234567890;
+select 1234567890123456789012345678901234567890.0;
+
+-- double
+select 1D, 1.2D, 1e10, 1.5e5, .10D, 0.10D, .1e5, .9e+2, 0.9e+2, 900e-1, 9.e+1;
+-- negative double
+select -1D, -1.2D, -1e10, -1.5e5, -.10D, -0.10D, -.1e5;
+-- invalid double: no digits, should fail to parse
+select .e3;
+-- inf and -inf
+select 1E309, -1E309;
+
+-- decimal parsing
+select 0.3, -0.8, .5, -.18, 0.1111, .1111;
+
+-- super large scientific notation numbers should still be valid doubles
+select 123456789012345678901234567890123456789e10, 123456789012345678901234567890123456789.1e10;
+
+-- string
+select "Hello Peter!", 'hello lee!';
+-- multi string
+select 'hello' 'world', 'hello' " " 'lee';
+-- single quote within double quotes
+select "hello 'peter'";
+select 'pattern%', 'no-pattern\%', 'pattern\\%', 'pattern\\\%';
+select '\'', '"', '\n', '\r', '\t', 'Z';
+-- "Hello!" in octals
+select '\110\145\154\154\157\041';
+-- "World :)" in unicode
+select '\u0057\u006F\u0072\u006C\u0064\u0020\u003A\u0029';
+
+-- date
+select dAte '2016-03-12';
+-- invalid date
+select date 'mar 11 2016';
+
+-- timestamp
+select tImEstAmp '2016-03-11 20:54:00.000';
+-- invalid timestamp
+select timestamp '2016-33-11 20:54:00.000';
+
+-- interval
+select interval 13.123456789 seconds, interval -13.123456789 second;
+select interval 1 year 2 month 3 week 4 day 5 hour 6 minute 7 seconds 8 millisecond, 9 microsecond;
+-- ns is not supported
+select interval 10 nanoseconds;
+
+-- unsupported data type
+select GEO '(10,-6)';
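The suffix letters exercised above (Y, S, L, D) pin a literal to tinyint, smallint, bigint and double respectively, which is exactly what the schema lines in the golden file assert. A quick sketch that checks the same mapping through the DataFrame schema (not part of the patch; assumes a `SparkSession` named `spark`):

{{{
// Sketch: literal suffixes map directly to Catalyst types.
import org.apache.spark.sql.types.{ByteType, DoubleType, LongType, ShortType}

val schema = spark.sql("select 1Y as y, 1S as s, 1L as l, 1D as d").schema
assert(schema("y").dataType == ByteType)    // tinyint
assert(schema("s").dataType == ShortType)   // smallint
assert(schema("l").dataType == LongType)    // bigint
assert(schema("d").dataType == DoubleType)
}}}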
diff --git a/sql/core/src/test/resources/sql-tests/inputs/natural-join.sql b/sql/core/src/test/resources/sql-tests/inputs/natural-join.sql
new file mode 100644
index 0000000000000..71a50157b766c
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/natural-join.sql
@@ -0,0 +1,20 @@
+create temporary view nt1 as select * from values
+ ("one", 1),
+ ("two", 2),
+ ("three", 3)
+ as nt1(k, v1);
+
+create temporary view nt2 as select * from values
+ ("one", 1),
+ ("two", 22),
+ ("one", 5)
+ as nt2(k, v2);
+
+
+SELECT * FROM nt1 natural join nt2 where k = "one";
+
+SELECT * FROM nt1 natural left join nt2 order by v1, v2;
+
+SELECT * FROM nt1 natural right join nt2 order by v1, v2;
+
+SELECT count(*) FROM nt1 natural full outer join nt2;
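NATURAL JOIN has no one-to-one DataFrame-API counterpart, which is part of why it is convenient to cover in a SQL input file; the closest API form is a USING-style join on the single shared column. A hedged sketch of that equivalence, assuming the nt1 and nt2 views above are registered:

{{{
// Sketch: here the natural join is the same as joining on the only common column "k",
// so the output carries a single k column followed by v1 and v2.
val joined = spark.table("nt1").join(spark.table("nt2"), Seq("k"))
joined.where("k = 'one'").show()
// Expected to match the first golden query below: (one, 1, 1) and (one, 1, 5).
}}}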
diff --git a/sql/core/src/test/resources/sql-tests/inputs/order-by-ordinal.sql b/sql/core/src/test/resources/sql-tests/inputs/order-by-ordinal.sql
new file mode 100644
index 0000000000000..8d733e77fa8d3
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/order-by-ordinal.sql
@@ -0,0 +1,36 @@
+-- order by and sort by ordinal positions
+
+create temporary view data as select * from values
+ (1, 1),
+ (1, 2),
+ (2, 1),
+ (2, 2),
+ (3, 1),
+ (3, 2)
+ as data(a, b);
+
+select * from data order by 1 desc;
+
+-- mix ordinal and column name
+select * from data order by 1 desc, b desc;
+
+-- order by multiple ordinals
+select * from data order by 1 desc, 2 desc;
+
+-- 1 + 0 is considered a constant (not an ordinal) and thus ignored
+select * from data order by 1 + 0 desc, b desc;
+
+-- negative cases: ordinal position out of range
+select * from data order by 0;
+select * from data order by -1;
+select * from data order by 3;
+
+-- sort by ordinal
+select * from data sort by 1 desc;
+
+-- turn off order by ordinal
+set spark.sql.orderByOrdinal=false;
+
+-- 0 is now a valid literal
+select * from data order by 0;
+select * from data sort by 0;
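As in the GROUP BY script, the tail of this file turns `spark.sql.orderByOrdinal` off so that a bare integer in ORDER BY or SORT BY is treated as a constant sort key and no longer affects the ordering. Inside a suite the same override is usually scoped with `withSQLConf`, as the tests removed from SQLQuerySuite later in this patch did; a sketch, assuming it runs inside a QueryTest-style suite where `withSQLConf` and `checkAnswer` are available:

{{{
// Sketch: scoped config override mirroring the removed "order by ordinal" tests.
import org.apache.spark.sql.internal.SQLConf

withSQLConf(SQLConf.ORDER_BY_ORDINAL.key -> "false") {
  // With ordinals disabled, the leading 1 is ignored and only b drives the ordering.
  checkAnswer(
    spark.sql("select * from data order by 1 desc, b asc"),
    spark.sql("select * from data order by b asc"))
}
}}}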
diff --git a/sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out b/sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out
new file mode 100644
index 0000000000000..6abe048af477d
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out
@@ -0,0 +1,226 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 28
+
+
+-- !query 0
+select -100
+-- !query 0 schema
+struct<-100:int>
+-- !query 0 output
+-100
+
+
+-- !query 1
+select +230
+-- !query 1 schema
+struct<230:int>
+-- !query 1 output
+230
+
+
+-- !query 2
+select -5.2
+-- !query 2 schema
+struct<-5.2:decimal(2,1)>
+-- !query 2 output
+-5.2
+
+
+-- !query 3
+select +6.8e0
+-- !query 3 schema
+struct<6.8:double>
+-- !query 3 output
+6.8
+
+
+-- !query 4
+select -key, +key from testdata where key = 2
+-- !query 4 schema
+struct<(- key):int,key:int>
+-- !query 4 output
+-2 2
+
+
+-- !query 5
+select -(key + 1), - key + 1, +(key + 5) from testdata where key = 1
+-- !query 5 schema
+struct<(- (key + 1)):int,((- key) + 1):int,(key + 5):int>
+-- !query 5 output
+-2 0 6
+
+
+-- !query 6
+select -max(key), +max(key) from testdata
+-- !query 6 schema
+struct<(- max(key)):int,max(key):int>
+-- !query 6 output
+-100 100
+
+
+-- !query 7
+select - (-10)
+-- !query 7 schema
+struct<(- -10):int>
+-- !query 7 output
+10
+
+
+-- !query 8
+select + (-key) from testdata where key = 32
+-- !query 8 schema
+struct<(- key):int>
+-- !query 8 output
+-32
+
+
+-- !query 9
+select - (+max(key)) from testdata
+-- !query 9 schema
+struct<(- max(key)):int>
+-- !query 9 output
+-100
+
+
+-- !query 10
+select - - 3
+-- !query 10 schema
+struct<(- -3):int>
+-- !query 10 output
+3
+
+
+-- !query 11
+select - + 20
+-- !query 11 schema
+struct<(- 20):int>
+-- !query 11 output
+-20
+
+
+-- !query 12
+select + + 100
+-- !query 12 schema
+struct<100:int>
+-- !query 12 output
+100
+
+
+-- !query 13
+select - - max(key) from testdata
+-- !query 13 schema
+struct<(- (- max(key))):int>
+-- !query 13 output
+100
+
+
+-- !query 14
+select + - key from testdata where key = 33
+-- !query 14 schema
+struct<(- key):int>
+-- !query 14 output
+-33
+
+
+-- !query 15
+select 5 / 2
+-- !query 15 schema
+struct<(CAST(5 AS DOUBLE) / CAST(2 AS DOUBLE)):double>
+-- !query 15 output
+2.5
+
+
+-- !query 16
+select 5 / 0
+-- !query 16 schema
+struct<(CAST(5 AS DOUBLE) / CAST(0 AS DOUBLE)):double>
+-- !query 16 output
+NULL
+
+
+-- !query 17
+select 5 / null
+-- !query 17 schema
+struct<(CAST(5 AS DOUBLE) / CAST(NULL AS DOUBLE)):double>
+-- !query 17 output
+NULL
+
+
+-- !query 18
+select null / 5
+-- !query 18 schema
+struct<(CAST(NULL AS DOUBLE) / CAST(5 AS DOUBLE)):double>
+-- !query 18 output
+NULL
+
+
+-- !query 19
+select 5 div 2
+-- !query 19 schema
+struct
+-- !query 19 output
+2
+
+
+-- !query 20
+select 5 div 0
+-- !query 20 schema
+struct
+-- !query 20 output
+NULL
+
+
+-- !query 21
+select 5 div null
+-- !query 21 schema
+struct
+-- !query 21 output
+NULL
+
+
+-- !query 22
+select null div 5
+-- !query 22 schema
+struct
+-- !query 22 output
+NULL
+
+
+-- !query 23
+select 1 + 2
+-- !query 23 schema
+struct<(1 + 2):int>
+-- !query 23 output
+3
+
+
+-- !query 24
+select 1 - 2
+-- !query 24 schema
+struct<(1 - 2):int>
+-- !query 24 output
+-1
+
+
+-- !query 25
+select 2 * 5
+-- !query 25 schema
+struct<(2 * 5):int>
+-- !query 25 output
+10
+
+
+-- !query 26
+select 5 % 3
+-- !query 26 schema
+struct<(5 % 3):int>
+-- !query 26 output
+2
+
+
+-- !query 27
+select pmod(-7, 3)
+-- !query 27 schema
+struct
+-- !query 27 output
+2
diff --git a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out
new file mode 100644
index 0000000000000..032e4258500fb
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out
@@ -0,0 +1,10 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 1
+
+
+-- !query 0
+select current_date = current_date(), current_timestamp = current_timestamp()
+-- !query 0 schema
+struct<(current_date() = current_date()):boolean,(current_timestamp() = current_timestamp()):boolean>
+-- !query 0 output
+true true
diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out
new file mode 100644
index 0000000000000..2f10b7ebc6d32
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out
@@ -0,0 +1,168 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 17
+
+
+-- !query 0
+create temporary view data as select * from values
+ (1, 1),
+ (1, 2),
+ (2, 1),
+ (2, 2),
+ (3, 1),
+ (3, 2)
+ as data(a, b)
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+select a, sum(b) from data group by 1
+-- !query 1 schema
+struct
+-- !query 1 output
+1 3
+2 3
+3 3
+
+
+-- !query 2
+select 1, 2, sum(b) from data group by 1, 2
+-- !query 2 schema
+struct<1:int,2:int,sum(b):bigint>
+-- !query 2 output
+1 2 9
+
+
+-- !query 3
+select a, 1, sum(b) from data group by a, 1
+-- !query 3 schema
+struct
+-- !query 3 output
+1 1 3
+2 1 3
+3 1 3
+
+
+-- !query 4
+select a, 1, sum(b) from data group by 1, 2
+-- !query 4 schema
+struct
+-- !query 4 output
+1 1 3
+2 1 3
+3 1 3
+
+
+-- !query 5
+select a, b + 2, count(2) from data group by a, 2
+-- !query 5 schema
+struct
+-- !query 5 output
+1 3 1
+1 4 1
+2 3 1
+2 4 1
+3 3 1
+3 4 1
+
+
+-- !query 6
+select a as aa, b + 2 as bb, count(2) from data group by 1, 2
+-- !query 6 schema
+struct
+-- !query 6 output
+1 3 1
+1 4 1
+2 3 1
+2 4 1
+3 3 1
+3 4 1
+
+
+-- !query 7
+select sum(b) from data group by 1 + 0
+-- !query 7 schema
+struct
+-- !query 7 output
+9
+
+
+-- !query 8
+select a, b from data group by -1
+-- !query 8 schema
+struct<>
+-- !query 8 output
+org.apache.spark.sql.AnalysisException
+GROUP BY position -1 is not in select list (valid range is [1, 2]); line 1 pos 31
+
+
+-- !query 9
+select a, b from data group by 0
+-- !query 9 schema
+struct<>
+-- !query 9 output
+org.apache.spark.sql.AnalysisException
+GROUP BY position 0 is not in select list (valid range is [1, 2]); line 1 pos 31
+
+
+-- !query 10
+select a, b from data group by 3
+-- !query 10 schema
+struct<>
+-- !query 10 output
+org.apache.spark.sql.AnalysisException
+GROUP BY position 3 is not in select list (valid range is [1, 2]); line 1 pos 31
+
+
+-- !query 11
+select a, b, sum(b) from data group by 3
+-- !query 11 schema
+struct<>
+-- !query 11 output
+org.apache.spark.sql.AnalysisException
+GROUP BY position 3 is an aggregate function, and aggregate functions are not allowed in GROUP BY; line 1 pos 39
+
+
+-- !query 12
+select a, b, sum(b) + 2 from data group by 3
+-- !query 12 schema
+struct<>
+-- !query 12 output
+org.apache.spark.sql.AnalysisException
+GROUP BY position 3 is an aggregate function, and aggregate functions are not allowed in GROUP BY; line 1 pos 43
+
+
+-- !query 13
+select a, rand(0), sum(b) from data group by a, 2
+-- !query 13 schema
+struct<>
+-- !query 13 output
+org.apache.spark.sql.AnalysisException
+nondeterministic expression rand(0) should not appear in grouping expression.;
+
+
+-- !query 14
+select * from data group by a, b, 1
+-- !query 14 schema
+struct<>
+-- !query 14 output
+org.apache.spark.sql.AnalysisException
+Star (*) is not allowed in select list when GROUP BY ordinal position is used;
+
+
+-- !query 15
+set spark.sql.groupByOrdinal=false
+-- !query 15 schema
+struct
+-- !query 15 output
+spark.sql.groupByOrdinal
+
+
+-- !query 16
+select sum(b) from data group by -1
+-- !query 16 schema
+struct
+-- !query 16 output
+9
diff --git a/sql/core/src/test/resources/sql-tests/results/having.sql.out b/sql/core/src/test/resources/sql-tests/results/having.sql.out
new file mode 100644
index 0000000000000..e0923832673cb
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/having.sql.out
@@ -0,0 +1,40 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 4
+
+
+-- !query 0
+create temporary view hav as select * from values
+ ("one", 1),
+ ("two", 2),
+ ("three", 3),
+ ("one", 5)
+ as hav(k, v)
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+SELECT k, sum(v) FROM hav GROUP BY k HAVING sum(v) > 2
+-- !query 1 schema
+struct
+-- !query 1 output
+one 6
+three 3
+
+
+-- !query 2
+SELECT count(k) FROM hav GROUP BY v + 1 HAVING v + 1 = 2
+-- !query 2 schema
+struct
+-- !query 2 output
+1
+
+
+-- !query 3
+SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0)
+-- !query 3 schema
+struct
+-- !query 3 output
+1
diff --git a/sql/core/src/test/resources/sql-tests/results/limit.sql.out b/sql/core/src/test/resources/sql-tests/results/limit.sql.out
new file mode 100644
index 0000000000000..b71b05886986c
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/limit.sql.out
@@ -0,0 +1,83 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 9
+
+
+-- !query 0
+select * from testdata limit 2
+-- !query 0 schema
+struct
+-- !query 0 output
+1 1
+2 2
+
+
+-- !query 1
+select * from arraydata limit 2
+-- !query 1 schema
+struct,nestedarraycol:array>>
+-- !query 1 output
+[1,2,3] [[1,2,3]]
+[2,3,4] [[2,3,4]]
+
+
+-- !query 2
+select * from mapdata limit 2
+-- !query 2 schema
+struct>
+-- !query 2 output
+{1:"a1",2:"b1",3:"c1",4:"d1",5:"e1"}
+{1:"a2",2:"b2",3:"c2",4:"d2"}
+
+
+-- !query 3
+select * from testdata limit 2 + 1
+-- !query 3 schema
+struct
+-- !query 3 output
+1 1
+2 2
+3 3
+
+
+-- !query 4
+select * from testdata limit CAST(1 AS int)
+-- !query 4 schema
+struct
+-- !query 4 output
+1 1
+
+
+-- !query 5
+select * from testdata limit -1
+-- !query 5 schema
+struct<>
+-- !query 5 output
+org.apache.spark.sql.AnalysisException
+The limit expression must be equal to or greater than 0, but got -1;
+
+
+-- !query 6
+select * from testdata limit key > 3
+-- !query 6 schema
+struct<>
+-- !query 6 output
+org.apache.spark.sql.AnalysisException
+The limit expression must evaluate to a constant value, but got (testdata.`key` > 3);
+
+
+-- !query 7
+select * from testdata limit true
+-- !query 7 schema
+struct<>
+-- !query 7 output
+org.apache.spark.sql.AnalysisException
+The limit expression must be integer type, but got boolean;
+
+
+-- !query 8
+select * from testdata limit 'a'
+-- !query 8 schema
+struct<>
+-- !query 8 output
+org.apache.spark.sql.AnalysisException
+The limit expression must be integer type, but got string;
diff --git a/sql/core/src/test/resources/sql-tests/results/literals.sql.out b/sql/core/src/test/resources/sql-tests/results/literals.sql.out
new file mode 100644
index 0000000000000..b964a6fc0921f
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/literals.sql.out
@@ -0,0 +1,356 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 38
+
+
+-- !query 0
+select null, Null, nUll
+-- !query 0 schema
+struct
+-- !query 0 output
+NULL NULL NULL
+
+
+-- !query 1
+select true, tRue, false, fALse
+-- !query 1 schema
+struct
+-- !query 1 output
+true true false false
+
+
+-- !query 2
+select 1Y
+-- !query 2 schema
+struct<1:tinyint>
+-- !query 2 output
+1
+
+
+-- !query 3
+select 127Y, -128Y
+-- !query 3 schema
+struct<127:tinyint,-128:tinyint>
+-- !query 3 output
+127 -128
+
+
+-- !query 4
+select 128Y
+-- !query 4 schema
+struct<>
+-- !query 4 output
+org.apache.spark.sql.catalyst.parser.ParseException
+
+Value out of range. Value:"128" Radix:10(line 1, pos 7)
+
+== SQL ==
+select 128Y
+-------^^^
+
+
+-- !query 5
+select 1S
+-- !query 5 schema
+struct<1:smallint>
+-- !query 5 output
+1
+
+
+-- !query 6
+select 32767S, -32768S
+-- !query 6 schema
+struct<32767:smallint,-32768:smallint>
+-- !query 6 output
+32767 -32768
+
+
+-- !query 7
+select 32768S
+-- !query 7 schema
+struct<>
+-- !query 7 output
+org.apache.spark.sql.catalyst.parser.ParseException
+
+Value out of range. Value:"32768" Radix:10(line 1, pos 7)
+
+== SQL ==
+select 32768S
+-------^^^
+
+
+-- !query 8
+select 1L, 2147483648L
+-- !query 8 schema
+struct<1:bigint,2147483648:bigint>
+-- !query 8 output
+1 2147483648
+
+
+-- !query 9
+select 9223372036854775807L, -9223372036854775808L
+-- !query 9 schema
+struct<9223372036854775807:bigint,-9223372036854775808:bigint>
+-- !query 9 output
+9223372036854775807 -9223372036854775808
+
+
+-- !query 10
+select 9223372036854775808L
+-- !query 10 schema
+struct<>
+-- !query 10 output
+org.apache.spark.sql.catalyst.parser.ParseException
+
+For input string: "9223372036854775808"(line 1, pos 7)
+
+== SQL ==
+select 9223372036854775808L
+-------^^^
+
+
+-- !query 11
+select 1, -1
+-- !query 11 schema
+struct<1:int,-1:int>
+-- !query 11 output
+1 -1
+
+
+-- !query 12
+select 2147483647, -2147483648
+-- !query 12 schema
+struct<2147483647:int,-2147483648:int>
+-- !query 12 output
+2147483647 -2147483648
+
+
+-- !query 13
+select 9223372036854775807, -9223372036854775808
+-- !query 13 schema
+struct<9223372036854775807:bigint,-9223372036854775808:bigint>
+-- !query 13 output
+9223372036854775807 -9223372036854775808
+
+
+-- !query 14
+select 9223372036854775808, -9223372036854775809
+-- !query 14 schema
+struct<9223372036854775808:decimal(19,0),-9223372036854775809:decimal(19,0)>
+-- !query 14 output
+9223372036854775808 -9223372036854775809
+
+
+-- !query 15
+select 1234567890123456789012345678901234567890
+-- !query 15 schema
+struct<>
+-- !query 15 output
+org.apache.spark.sql.catalyst.parser.ParseException
+
+DecimalType can only support precision up to 38
+== SQL ==
+select 1234567890123456789012345678901234567890
+
+
+-- !query 16
+select 1234567890123456789012345678901234567890.0
+-- !query 16 schema
+struct<>
+-- !query 16 output
+org.apache.spark.sql.catalyst.parser.ParseException
+
+DecimalType can only support precision up to 38
+== SQL ==
+select 1234567890123456789012345678901234567890.0
+
+
+-- !query 17
+select 1D, 1.2D, 1e10, 1.5e5, .10D, 0.10D, .1e5, .9e+2, 0.9e+2, 900e-1, 9.e+1
+-- !query 17 schema
+struct<1.0:double,1.2:double,1.0E10:double,150000.0:double,0.1:double,0.1:double,10000.0:double,90.0:double,90.0:double,90.0:double,90.0:double>
+-- !query 17 output
+1.0 1.2 1.0E10 150000.0 0.1 0.1 10000.0 90.0 90.0 90.0 90.0
+
+
+-- !query 18
+select -1D, -1.2D, -1e10, -1.5e5, -.10D, -0.10D, -.1e5
+-- !query 18 schema
+struct<-1.0:double,-1.2:double,-1.0E10:double,-150000.0:double,-0.1:double,-0.1:double,-10000.0:double>
+-- !query 18 output
+-1.0 -1.2 -1.0E10 -150000.0 -0.1 -0.1 -10000.0
+
+
+-- !query 19
+select .e3
+-- !query 19 schema
+struct<>
+-- !query 19 output
+org.apache.spark.sql.catalyst.parser.ParseException
+
+no viable alternative at input 'select .'(line 1, pos 7)
+
+== SQL ==
+select .e3
+-------^^^
+
+
+-- !query 20
+select 1E309, -1E309
+-- !query 20 schema
+struct
+-- !query 20 output
+Infinity -Infinity
+
+
+-- !query 21
+select 0.3, -0.8, .5, -.18, 0.1111, .1111
+-- !query 21 schema
+struct<0.3:decimal(1,1),-0.8:decimal(1,1),0.5:decimal(1,1),-0.18:decimal(2,2),0.1111:decimal(4,4),0.1111:decimal(4,4)>
+-- !query 21 output
+0.3 -0.8 0.5 -0.18 0.1111 0.1111
+
+
+-- !query 22
+select 123456789012345678901234567890123456789e10, 123456789012345678901234567890123456789.1e10
+-- !query 22 schema
+struct<1.2345678901234568E48:double,1.2345678901234568E48:double>
+-- !query 22 output
+1.2345678901234568E48 1.2345678901234568E48
+
+
+-- !query 23
+select "Hello Peter!", 'hello lee!'
+-- !query 23 schema
+struct
+-- !query 23 output
+Hello Peter! hello lee!
+
+
+-- !query 24
+select 'hello' 'world', 'hello' " " 'lee'
+-- !query 24 schema
+struct
+-- !query 24 output
+helloworld hello lee
+
+
+-- !query 25
+select "hello 'peter'"
+-- !query 25 schema
+struct
+-- !query 25 output
+hello 'peter'
+
+
+-- !query 26
+select 'pattern%', 'no-pattern\%', 'pattern\\%', 'pattern\\\%'
+-- !query 26 schema
+struct
+-- !query 26 output
+pattern% no-pattern\% pattern\% pattern\\%
+
+
+-- !query 27
+select '\'', '"', '\n', '\r', '\t', 'Z'
+-- !query 27 schema
+struct<':string,":string,
+:string,
:string, :string,Z:string>
+-- !query 27 output
+' "
+
Z
+
+
+-- !query 28
+select '\110\145\154\154\157\041'
+-- !query 28 schema
+struct
+-- !query 28 output
+Hello!
+
+
+-- !query 29
+select '\u0057\u006F\u0072\u006C\u0064\u0020\u003A\u0029'
+-- !query 29 schema
+struct
+-- !query 29 output
+World :)
+
+
+-- !query 30
+select dAte '2016-03-12'
+-- !query 30 schema
+struct
+-- !query 30 output
+2016-03-12
+
+
+-- !query 31
+select date 'mar 11 2016'
+-- !query 31 schema
+struct<>
+-- !query 31 output
+java.lang.IllegalArgumentException
+null
+
+
+-- !query 32
+select tImEstAmp '2016-03-11 20:54:00.000'
+-- !query 32 schema
+struct
+-- !query 32 output
+2016-03-11 20:54:00
+
+
+-- !query 33
+select timestamp '2016-33-11 20:54:00.000'
+-- !query 33 schema
+struct<>
+-- !query 33 output
+java.lang.IllegalArgumentException
+Timestamp format must be yyyy-mm-dd hh:mm:ss[.fffffffff]
+
+
+-- !query 34
+select interval 13.123456789 seconds, interval -13.123456789 second
+-- !query 34 schema
+struct<>
+-- !query 34 output
+scala.MatchError
+(interval 13 seconds 123 milliseconds 456 microseconds,CalendarIntervalType) (of class scala.Tuple2)
+
+
+-- !query 35
+select interval 1 year 2 month 3 week 4 day 5 hour 6 minute 7 seconds 8 millisecond, 9 microsecond
+-- !query 35 schema
+struct<>
+-- !query 35 output
+scala.MatchError
+(interval 1 years 2 months 3 weeks 4 days 5 hours 6 minutes 7 seconds 8 milliseconds,CalendarIntervalType) (of class scala.Tuple2)
+
+
+-- !query 36
+select interval 10 nanoseconds
+-- !query 36 schema
+struct<>
+-- !query 36 output
+org.apache.spark.sql.catalyst.parser.ParseException
+
+No interval can be constructed(line 1, pos 16)
+
+== SQL ==
+select interval 10 nanoseconds
+----------------^^^
+
+
+-- !query 37
+select GEO '(10,-6)'
+-- !query 37 schema
+struct<>
+-- !query 37 output
+org.apache.spark.sql.catalyst.parser.ParseException
+
+Literals of type 'GEO' are currently not supported.(line 1, pos 7)
+
+== SQL ==
+select GEO '(10,-6)'
+-------^^^
diff --git a/sql/core/src/test/resources/sql-tests/results/natural-join.sql.out b/sql/core/src/test/resources/sql-tests/results/natural-join.sql.out
new file mode 100644
index 0000000000000..43f2f9af61d9b
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/natural-join.sql.out
@@ -0,0 +1,64 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 6
+
+
+-- !query 0
+create temporary view nt1 as select * from values
+ ("one", 1),
+ ("two", 2),
+ ("three", 3)
+ as nt1(k, v1)
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+create temporary view nt2 as select * from values
+ ("one", 1),
+ ("two", 22),
+ ("one", 5)
+ as nt2(k, v2)
+-- !query 1 schema
+struct<>
+-- !query 1 output
+
+
+
+-- !query 2
+SELECT * FROM nt1 natural join nt2 where k = "one"
+-- !query 2 schema
+struct
+-- !query 2 output
+one 1 1
+one 1 5
+
+
+-- !query 3
+SELECT * FROM nt1 natural left join nt2 order by v1, v2
+-- !query 3 schema
+struct
+-- !query 3 output
+one 1 1
+one 1 5
+two 2 22
+three 3 NULL
+
+
+-- !query 4
+SELECT * FROM nt1 natural right join nt2 order by v1, v2
+-- !query 4 schema
+struct
+-- !query 4 output
+one 1 1
+one 1 5
+two 2 22
+
+
+-- !query 5
+SELECT count(*) FROM nt1 natural full outer join nt2
+-- !query 5 schema
+struct
+-- !query 5 output
+4
diff --git a/sql/core/src/test/resources/sql-tests/results/order-by-ordinal.sql.out b/sql/core/src/test/resources/sql-tests/results/order-by-ordinal.sql.out
new file mode 100644
index 0000000000000..03a4e72d0fa3e
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/order-by-ordinal.sql.out
@@ -0,0 +1,143 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 12
+
+
+-- !query 0
+create temporary view data as select * from values
+ (1, 1),
+ (1, 2),
+ (2, 1),
+ (2, 2),
+ (3, 1),
+ (3, 2)
+ as data(a, b)
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+select * from data order by 1 desc
+-- !query 1 schema
+struct
+-- !query 1 output
+3 1
+3 2
+2 1
+2 2
+1 1
+1 2
+
+
+-- !query 2
+select * from data order by 1 desc, b desc
+-- !query 2 schema
+struct
+-- !query 2 output
+3 2
+3 1
+2 2
+2 1
+1 2
+1 1
+
+
+-- !query 3
+select * from data order by 1 desc, 2 desc
+-- !query 3 schema
+struct
+-- !query 3 output
+3 2
+3 1
+2 2
+2 1
+1 2
+1 1
+
+
+-- !query 4
+select * from data order by 1 + 0 desc, b desc
+-- !query 4 schema
+struct
+-- !query 4 output
+1 2
+2 2
+3 2
+1 1
+2 1
+3 1
+
+
+-- !query 5
+select * from data order by 0
+-- !query 5 schema
+struct<>
+-- !query 5 output
+org.apache.spark.sql.AnalysisException
+ORDER BY position 0 is not in select list (valid range is [1, 2]); line 1 pos 28
+
+
+-- !query 6
+select * from data order by -1
+-- !query 6 schema
+struct<>
+-- !query 6 output
+org.apache.spark.sql.AnalysisException
+ORDER BY position -1 is not in select list (valid range is [1, 2]); line 1 pos 28
+
+
+-- !query 7
+select * from data order by 3
+-- !query 7 schema
+struct<>
+-- !query 7 output
+org.apache.spark.sql.AnalysisException
+ORDER BY position 3 is not in select list (valid range is [1, 2]); line 1 pos 28
+
+
+-- !query 8
+select * from data sort by 1 desc
+-- !query 8 schema
+struct
+-- !query 8 output
+1 1
+1 2
+2 1
+2 2
+3 1
+3 2
+
+
+-- !query 9
+set spark.sql.orderByOrdinal=false
+-- !query 9 schema
+struct
+-- !query 9 output
+spark.sql.orderByOrdinal
+
+
+-- !query 10
+select * from data order by 0
+-- !query 10 schema
+struct
+-- !query 10 output
+1 1
+1 2
+2 1
+2 2
+3 1
+3 2
+
+
+-- !query 11
+select * from data sort by 0
+-- !query 11 schema
+struct
+-- !query 11 output
+1 1
+1 2
+2 1
+2 2
+3 1
+3 2
diff --git a/sql/core/src/test/resources/bool.csv b/sql/core/src/test/resources/test-data/bool.csv
similarity index 100%
rename from sql/core/src/test/resources/bool.csv
rename to sql/core/src/test/resources/test-data/bool.csv
diff --git a/sql/core/src/test/resources/cars-alternative.csv b/sql/core/src/test/resources/test-data/cars-alternative.csv
similarity index 100%
rename from sql/core/src/test/resources/cars-alternative.csv
rename to sql/core/src/test/resources/test-data/cars-alternative.csv
diff --git a/sql/core/src/test/resources/cars-blank-column-name.csv b/sql/core/src/test/resources/test-data/cars-blank-column-name.csv
similarity index 100%
rename from sql/core/src/test/resources/cars-blank-column-name.csv
rename to sql/core/src/test/resources/test-data/cars-blank-column-name.csv
diff --git a/sql/core/src/test/resources/cars-malformed.csv b/sql/core/src/test/resources/test-data/cars-malformed.csv
similarity index 100%
rename from sql/core/src/test/resources/cars-malformed.csv
rename to sql/core/src/test/resources/test-data/cars-malformed.csv
diff --git a/sql/core/src/test/resources/cars-null.csv b/sql/core/src/test/resources/test-data/cars-null.csv
similarity index 100%
rename from sql/core/src/test/resources/cars-null.csv
rename to sql/core/src/test/resources/test-data/cars-null.csv
diff --git a/sql/core/src/test/resources/cars-unbalanced-quotes.csv b/sql/core/src/test/resources/test-data/cars-unbalanced-quotes.csv
similarity index 100%
rename from sql/core/src/test/resources/cars-unbalanced-quotes.csv
rename to sql/core/src/test/resources/test-data/cars-unbalanced-quotes.csv
diff --git a/sql/core/src/test/resources/cars.csv b/sql/core/src/test/resources/test-data/cars.csv
similarity index 100%
rename from sql/core/src/test/resources/cars.csv
rename to sql/core/src/test/resources/test-data/cars.csv
diff --git a/sql/core/src/test/resources/cars.tsv b/sql/core/src/test/resources/test-data/cars.tsv
similarity index 100%
rename from sql/core/src/test/resources/cars.tsv
rename to sql/core/src/test/resources/test-data/cars.tsv
diff --git a/sql/core/src/test/resources/cars_iso-8859-1.csv b/sql/core/src/test/resources/test-data/cars_iso-8859-1.csv
similarity index 100%
rename from sql/core/src/test/resources/cars_iso-8859-1.csv
rename to sql/core/src/test/resources/test-data/cars_iso-8859-1.csv
diff --git a/sql/core/src/test/resources/comments.csv b/sql/core/src/test/resources/test-data/comments.csv
similarity index 100%
rename from sql/core/src/test/resources/comments.csv
rename to sql/core/src/test/resources/test-data/comments.csv
diff --git a/sql/core/src/test/resources/dates.csv b/sql/core/src/test/resources/test-data/dates.csv
similarity index 100%
rename from sql/core/src/test/resources/dates.csv
rename to sql/core/src/test/resources/test-data/dates.csv
diff --git a/sql/core/src/test/resources/dec-in-fixed-len.parquet b/sql/core/src/test/resources/test-data/dec-in-fixed-len.parquet
similarity index 100%
rename from sql/core/src/test/resources/dec-in-fixed-len.parquet
rename to sql/core/src/test/resources/test-data/dec-in-fixed-len.parquet
diff --git a/sql/core/src/test/resources/dec-in-i32.parquet b/sql/core/src/test/resources/test-data/dec-in-i32.parquet
similarity index 100%
rename from sql/core/src/test/resources/dec-in-i32.parquet
rename to sql/core/src/test/resources/test-data/dec-in-i32.parquet
diff --git a/sql/core/src/test/resources/dec-in-i64.parquet b/sql/core/src/test/resources/test-data/dec-in-i64.parquet
similarity index 100%
rename from sql/core/src/test/resources/dec-in-i64.parquet
rename to sql/core/src/test/resources/test-data/dec-in-i64.parquet
diff --git a/sql/core/src/test/resources/decimal.csv b/sql/core/src/test/resources/test-data/decimal.csv
similarity index 100%
rename from sql/core/src/test/resources/decimal.csv
rename to sql/core/src/test/resources/test-data/decimal.csv
diff --git a/sql/core/src/test/resources/disable_comments.csv b/sql/core/src/test/resources/test-data/disable_comments.csv
similarity index 100%
rename from sql/core/src/test/resources/disable_comments.csv
rename to sql/core/src/test/resources/test-data/disable_comments.csv
diff --git a/sql/core/src/test/resources/empty.csv b/sql/core/src/test/resources/test-data/empty.csv
similarity index 100%
rename from sql/core/src/test/resources/empty.csv
rename to sql/core/src/test/resources/test-data/empty.csv
diff --git a/sql/core/src/test/resources/nested-array-struct.parquet b/sql/core/src/test/resources/test-data/nested-array-struct.parquet
similarity index 100%
rename from sql/core/src/test/resources/nested-array-struct.parquet
rename to sql/core/src/test/resources/test-data/nested-array-struct.parquet
diff --git a/sql/core/src/test/resources/numbers.csv b/sql/core/src/test/resources/test-data/numbers.csv
similarity index 100%
rename from sql/core/src/test/resources/numbers.csv
rename to sql/core/src/test/resources/test-data/numbers.csv
diff --git a/sql/core/src/test/resources/old-repeated-int.parquet b/sql/core/src/test/resources/test-data/old-repeated-int.parquet
similarity index 100%
rename from sql/core/src/test/resources/old-repeated-int.parquet
rename to sql/core/src/test/resources/test-data/old-repeated-int.parquet
diff --git a/sql/core/src/test/resources/old-repeated-message.parquet b/sql/core/src/test/resources/test-data/old-repeated-message.parquet
similarity index 100%
rename from sql/core/src/test/resources/old-repeated-message.parquet
rename to sql/core/src/test/resources/test-data/old-repeated-message.parquet
diff --git a/sql/core/src/test/resources/parquet-thrift-compat.snappy.parquet b/sql/core/src/test/resources/test-data/parquet-thrift-compat.snappy.parquet
similarity index 100%
rename from sql/core/src/test/resources/parquet-thrift-compat.snappy.parquet
rename to sql/core/src/test/resources/test-data/parquet-thrift-compat.snappy.parquet
diff --git a/sql/core/src/test/resources/proto-repeated-string.parquet b/sql/core/src/test/resources/test-data/proto-repeated-string.parquet
similarity index 100%
rename from sql/core/src/test/resources/proto-repeated-string.parquet
rename to sql/core/src/test/resources/test-data/proto-repeated-string.parquet
diff --git a/sql/core/src/test/resources/proto-repeated-struct.parquet b/sql/core/src/test/resources/test-data/proto-repeated-struct.parquet
similarity index 100%
rename from sql/core/src/test/resources/proto-repeated-struct.parquet
rename to sql/core/src/test/resources/test-data/proto-repeated-struct.parquet
diff --git a/sql/core/src/test/resources/proto-struct-with-array-many.parquet b/sql/core/src/test/resources/test-data/proto-struct-with-array-many.parquet
similarity index 100%
rename from sql/core/src/test/resources/proto-struct-with-array-many.parquet
rename to sql/core/src/test/resources/test-data/proto-struct-with-array-many.parquet
diff --git a/sql/core/src/test/resources/proto-struct-with-array.parquet b/sql/core/src/test/resources/test-data/proto-struct-with-array.parquet
similarity index 100%
rename from sql/core/src/test/resources/proto-struct-with-array.parquet
rename to sql/core/src/test/resources/test-data/proto-struct-with-array.parquet
diff --git a/sql/core/src/test/resources/simple_sparse.csv b/sql/core/src/test/resources/test-data/simple_sparse.csv
similarity index 100%
rename from sql/core/src/test/resources/simple_sparse.csv
rename to sql/core/src/test/resources/test-data/simple_sparse.csv
diff --git a/sql/core/src/test/resources/text-partitioned/year=2014/data.txt b/sql/core/src/test/resources/test-data/text-partitioned/year=2014/data.txt
similarity index 100%
rename from sql/core/src/test/resources/text-partitioned/year=2014/data.txt
rename to sql/core/src/test/resources/test-data/text-partitioned/year=2014/data.txt
diff --git a/sql/core/src/test/resources/text-partitioned/year=2015/data.txt b/sql/core/src/test/resources/test-data/text-partitioned/year=2015/data.txt
similarity index 100%
rename from sql/core/src/test/resources/text-partitioned/year=2015/data.txt
rename to sql/core/src/test/resources/test-data/text-partitioned/year=2015/data.txt
diff --git a/sql/core/src/test/resources/text-suite.txt b/sql/core/src/test/resources/test-data/text-suite.txt
similarity index 100%
rename from sql/core/src/test/resources/text-suite.txt
rename to sql/core/src/test/resources/test-data/text-suite.txt
diff --git a/sql/core/src/test/resources/text-suite2.txt b/sql/core/src/test/resources/test-data/text-suite2.txt
similarity index 100%
rename from sql/core/src/test/resources/text-suite2.txt
rename to sql/core/src/test/resources/test-data/text-suite2.txt
diff --git a/sql/core/src/test/resources/unescaped-quotes.csv b/sql/core/src/test/resources/test-data/unescaped-quotes.csv
similarity index 100%
rename from sql/core/src/test/resources/unescaped-quotes.csv
rename to sql/core/src/test/resources/test-data/unescaped-quotes.csv
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
index 72f676e6225ee..1230b921aa279 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql
+import org.apache.spark.sql.catalyst.DefinedByConstructorParams
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
@@ -58,4 +59,43 @@ class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext {
val nullIntRow = df.selectExpr("i[1]").collect()(0)
assert(nullIntRow == org.apache.spark.sql.Row(null))
}
+
+ test("SPARK-15285 Generated SpecificSafeProjection.apply method grows beyond 64KB") {
+ val ds100_5 = Seq(S100_5()).toDS()
+ ds100_5.rdd.count
+ }
}
+
+class S100(
+ val s1: String = "1", val s2: String = "2", val s3: String = "3", val s4: String = "4",
+ val s5: String = "5", val s6: String = "6", val s7: String = "7", val s8: String = "8",
+ val s9: String = "9", val s10: String = "10", val s11: String = "11", val s12: String = "12",
+ val s13: String = "13", val s14: String = "14", val s15: String = "15", val s16: String = "16",
+ val s17: String = "17", val s18: String = "18", val s19: String = "19", val s20: String = "20",
+ val s21: String = "21", val s22: String = "22", val s23: String = "23", val s24: String = "24",
+ val s25: String = "25", val s26: String = "26", val s27: String = "27", val s28: String = "28",
+ val s29: String = "29", val s30: String = "30", val s31: String = "31", val s32: String = "32",
+ val s33: String = "33", val s34: String = "34", val s35: String = "35", val s36: String = "36",
+ val s37: String = "37", val s38: String = "38", val s39: String = "39", val s40: String = "40",
+ val s41: String = "41", val s42: String = "42", val s43: String = "43", val s44: String = "44",
+ val s45: String = "45", val s46: String = "46", val s47: String = "47", val s48: String = "48",
+ val s49: String = "49", val s50: String = "50", val s51: String = "51", val s52: String = "52",
+ val s53: String = "53", val s54: String = "54", val s55: String = "55", val s56: String = "56",
+ val s57: String = "57", val s58: String = "58", val s59: String = "59", val s60: String = "60",
+ val s61: String = "61", val s62: String = "62", val s63: String = "63", val s64: String = "64",
+ val s65: String = "65", val s66: String = "66", val s67: String = "67", val s68: String = "68",
+ val s69: String = "69", val s70: String = "70", val s71: String = "71", val s72: String = "72",
+ val s73: String = "73", val s74: String = "74", val s75: String = "75", val s76: String = "76",
+ val s77: String = "77", val s78: String = "78", val s79: String = "79", val s80: String = "80",
+ val s81: String = "81", val s82: String = "82", val s83: String = "83", val s84: String = "84",
+ val s85: String = "85", val s86: String = "86", val s87: String = "87", val s88: String = "88",
+ val s89: String = "89", val s90: String = "90", val s91: String = "91", val s92: String = "92",
+ val s93: String = "93", val s94: String = "94", val s95: String = "95", val s96: String = "96",
+ val s97: String = "97", val s98: String = "98", val s99: String = "99", val s100: String = "100")
+extends DefinedByConstructorParams
+
+case class S100_5(
+ s1: S100 = new S100(), s2: S100 = new S100(), s3: S100 = new S100(),
+ s4: S100 = new S100(), s5: S100 = new S100())
+
+
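The regression test relies on sheer width: `S100_5` nests five `S100` values, so the encoder has to handle 500 string fields, which is what used to push the generated `SpecificSafeProjection.apply` method past the JVM's 64KB bytecode-per-method limit. A rough sketch of the shape being exercised (assumes `spark.implicits._` is in scope, as in the suite):

{{{
// Sketch: 5 struct columns x 100 string fields = 500 leaf fields in the schema,
// which previously blew up the generated safe projection (SPARK-15285).
import org.apache.spark.sql.types.StructType

val ds = Seq(S100_5()).toDS()
val leaves = ds.schema.fields.map(_.dataType.asInstanceOf[StructType].size).sum
assert(ds.schema.size == 5 && leaves == 500)
ds.rdd.count()  // used to fail during codegen with "grows beyond 64 KB"
}}}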
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 43cbc03b7aa0c..9aeeda4463afc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -844,6 +844,19 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
val ds = spark.createDataset(data)(enc)
checkDataset(ds, (("a", "b"), "c"), (null, "d"))
}
+
+ test("SPARK-16995: flat mapping on Dataset containing a column created with lit/expr") {
+ val df = Seq("1").toDF("a")
+
+ import df.sparkSession.implicits._
+
+ checkDataset(
+ df.withColumn("b", lit(0)).as[ClassData]
+ .groupByKey(_.a).flatMapGroups { case (x, iter) => List[Int]() })
+ checkDataset(
+ df.withColumn("b", expr("0")).as[ClassData]
+ .groupByKey(_.a).flatMapGroups { case (x, iter) => List[Int]() })
+ }
}
case class Generic[T](id: T, value: Double)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index e8480a7001760..b2c051d11cf20 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -242,9 +242,10 @@ abstract class QueryTest extends PlanTest {
case p if p.getClass.getSimpleName == "MetastoreRelation" => return
case _: MemoryPlan => return
}.transformAllExpressions {
- case a: ImperativeAggregate => return
+ case _: ImperativeAggregate => return
case _: TypedAggregateExpression => return
case Literal(_, _: ObjectType) => return
+ case _: UserDefinedGenerator => return
}
// bypass hive tests before we fix all corner cases in hive module.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 393b4af285498..de1a811d642bb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql
+import java.io.File
import java.math.MathContext
import java.sql.{Date, Timestamp}
@@ -38,26 +39,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
setupTestData()
- test("having clause") {
- withTempView("hav") {
- Seq(("one", 1), ("two", 2), ("three", 3), ("one", 5)).toDF("k", "v")
- .createOrReplaceTempView("hav")
- checkAnswer(
- sql("SELECT k, sum(v) FROM hav GROUP BY k HAVING sum(v) > 2"),
- Row("one", 6) :: Row("three", 3) :: Nil)
- }
- }
-
- test("having condition contains grouping column") {
- withTempView("hav") {
- Seq(("one", 1), ("two", 2), ("three", 3), ("one", 5)).toDF("k", "v")
- .createOrReplaceTempView("hav")
- checkAnswer(
- sql("SELECT count(k) FROM hav GROUP BY v + 1 HAVING v + 1 = 2"),
- Row(1) :: Nil)
- }
- }
-
test("SPARK-8010: promote numeric to string") {
val df = Seq((1, 1)).toDF("key", "value")
df.createOrReplaceTempView("src")
@@ -507,103 +488,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
Seq(Row(1, 3), Row(2, 3), Row(3, 3)))
}
- test("Group By Ordinal - basic") {
- checkAnswer(
- sql("SELECT a, sum(b) FROM testData2 GROUP BY 1"),
- sql("SELECT a, sum(b) FROM testData2 GROUP BY a"))
-
- // duplicate group-by columns
- checkAnswer(
- sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1"),
- sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a"))
-
- checkAnswer(
- sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY 1, 2"),
- sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a"))
- }
-
- test("Group By Ordinal - non aggregate expressions") {
- checkAnswer(
- sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, 2"),
- sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, b + 2"))
-
- checkAnswer(
- sql("SELECT a, b + 2 as c, count(2) FROM testData2 GROUP BY a, 2"),
- sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, b + 2"))
- }
-
- test("Group By Ordinal - non-foldable constant expression") {
- checkAnswer(
- sql("SELECT a, b, sum(b) FROM testData2 GROUP BY a, b, 1 + 0"),
- sql("SELECT a, b, sum(b) FROM testData2 GROUP BY a, b"))
-
- checkAnswer(
- sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"),
- sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a"))
- }
-
- test("Group By Ordinal - alias") {
- checkAnswer(
- sql("SELECT a, (b + 2) as c, count(2) FROM testData2 GROUP BY a, 2"),
- sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, b + 2"))
-
- checkAnswer(
- sql("SELECT a as b, b as a, sum(b) FROM testData2 GROUP BY 1, 2"),
- sql("SELECT a, b, sum(b) FROM testData2 GROUP BY a, b"))
- }
-
- test("Group By Ordinal - constants") {
- checkAnswer(
- sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY 1, 2"),
- sql("SELECT 1, 2, sum(b) FROM testData2"))
- }
-
- test("Group By Ordinal - negative cases") {
- intercept[UnresolvedException[Aggregate]] {
- sql("SELECT a, b FROM testData2 GROUP BY -1")
- }
-
- intercept[UnresolvedException[Aggregate]] {
- sql("SELECT a, b FROM testData2 GROUP BY 3")
- }
-
- var e = intercept[UnresolvedException[Aggregate]](
- sql("SELECT SUM(a) FROM testData2 GROUP BY 1"))
- assert(e.getMessage contains
- "Invalid call to Group by position: the '1'th column in the select contains " +
- "an aggregate function")
-
- e = intercept[UnresolvedException[Aggregate]](
- sql("SELECT SUM(a) + 1 FROM testData2 GROUP BY 1"))
- assert(e.getMessage contains
- "Invalid call to Group by position: the '1'th column in the select contains " +
- "an aggregate function")
-
- var ae = intercept[AnalysisException](
- sql("SELECT a, rand(0), sum(b) FROM testData2 GROUP BY a, 2"))
- assert(ae.getMessage contains
- "nondeterministic expression rand(0) should not appear in grouping expression")
-
- ae = intercept[AnalysisException](
- sql("SELECT * FROM testData2 GROUP BY a, b, 1"))
- assert(ae.getMessage contains
- "Group by position: star is not allowed to use in the select list " +
- "when using ordinals in group by")
- }
-
- test("Group By Ordinal: spark.sql.groupByOrdinal=false") {
- withSQLConf(SQLConf.GROUP_BY_ORDINAL.key -> "false") {
- // If spark.sql.groupByOrdinal=false, ignore the position number.
- intercept[AnalysisException] {
- sql("SELECT a, sum(b) FROM testData2 GROUP BY 1")
- }
- // '*' is not allowed to use in the select list when users specify ordinals in group by
- checkAnswer(
- sql("SELECT * FROM testData2 GROUP BY a, b, 1"),
- sql("SELECT * FROM testData2 GROUP BY a, b"))
- }
- }
-
test("aggregates with nulls") {
checkAnswer(
sql("SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a)," +
@@ -670,51 +554,12 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
sortTest()
}
- test("limit") {
- checkAnswer(
- sql("SELECT * FROM testData LIMIT 9 + 1"),
- testData.take(10).toSeq)
-
- checkAnswer(
- sql("SELECT * FROM arrayData LIMIT CAST(1 AS Integer)"),
- arrayData.collect().take(1).map(Row.fromTuple).toSeq)
-
- checkAnswer(
- sql("SELECT * FROM mapData LIMIT 1"),
- mapData.collect().take(1).map(Row.fromTuple).toSeq)
- }
-
- test("non-foldable expressions in LIMIT") {
- val e = intercept[AnalysisException] {
- sql("SELECT * FROM testData LIMIT key > 3")
- }.getMessage
- assert(e.contains("The limit expression must evaluate to a constant value, " +
- "but got (testdata.`key` > 3)"))
- }
-
- test("Expressions in limit clause are not integer") {
- var e = intercept[AnalysisException] {
- sql("SELECT * FROM testData LIMIT true")
- }.getMessage
- assert(e.contains("The limit expression must be integer type, but got boolean"))
-
- e = intercept[AnalysisException] {
- sql("SELECT * FROM testData LIMIT 'a'")
- }.getMessage
- assert(e.contains("The limit expression must be integer type, but got string"))
- }
-
test("negative in LIMIT or TABLESAMPLE") {
val expected = "The limit expression must be equal to or greater than 0, but got -1"
var e = intercept[AnalysisException] {
sql("SELECT * FROM testData TABLESAMPLE (-1 rows)")
}.getMessage
assert(e.contains(expected))
-
- e = intercept[AnalysisException] {
- sql("SELECT * FROM testData LIMIT -1")
- }.getMessage
- assert(e.contains(expected))
}
test("CTE feature") {
@@ -1347,136 +1192,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
}
- test("Test to check we can use Long.MinValue") {
- checkAnswer(
- sql(s"SELECT ${Long.MinValue} FROM testData ORDER BY key LIMIT 1"), Row(Long.MinValue)
- )
-
- checkAnswer(
- sql(s"SELECT key FROM testData WHERE key > ${Long.MinValue}"),
- (1 to 100).map(Row(_)).toSeq
- )
- }
-
- test("Floating point number format") {
- checkAnswer(
- sql("SELECT 0.3"), Row(BigDecimal(0.3))
- )
-
- checkAnswer(
- sql("SELECT -0.8"), Row(BigDecimal(-0.8))
- )
-
- checkAnswer(
- sql("SELECT .5"), Row(BigDecimal(0.5))
- )
-
- checkAnswer(
- sql("SELECT -.18"), Row(BigDecimal(-0.18))
- )
- }
-
- test("Auto cast integer type") {
- checkAnswer(
- sql(s"SELECT ${Int.MaxValue + 1L}"), Row(Int.MaxValue + 1L)
- )
-
- checkAnswer(
- sql(s"SELECT ${Int.MinValue - 1L}"), Row(Int.MinValue - 1L)
- )
-
- checkAnswer(
- sql("SELECT 9223372036854775808"), Row(new java.math.BigDecimal("9223372036854775808"))
- )
-
- checkAnswer(
- sql("SELECT -9223372036854775809"), Row(new java.math.BigDecimal("-9223372036854775809"))
- )
- }
-
- test("Test to check we can apply sign to expression") {
-
- checkAnswer(
- sql("SELECT -100"), Row(-100)
- )
-
- checkAnswer(
- sql("SELECT +230"), Row(230)
- )
-
- checkAnswer(
- sql("SELECT -5.2"), Row(BigDecimal(-5.2))
- )
-
- checkAnswer(
- sql("SELECT +6.8e0"), Row(6.8d)
- )
-
- checkAnswer(
- sql("SELECT -key FROM testData WHERE key = 2"), Row(-2)
- )
-
- checkAnswer(
- sql("SELECT +key FROM testData WHERE key = 3"), Row(3)
- )
-
- checkAnswer(
- sql("SELECT -(key + 1) FROM testData WHERE key = 1"), Row(-2)
- )
-
- checkAnswer(
- sql("SELECT - key + 1 FROM testData WHERE key = 10"), Row(-9)
- )
-
- checkAnswer(
- sql("SELECT +(key + 5) FROM testData WHERE key = 5"), Row(10)
- )
-
- checkAnswer(
- sql("SELECT -MAX(key) FROM testData"), Row(-100)
- )
-
- checkAnswer(
- sql("SELECT +MAX(key) FROM testData"), Row(100)
- )
-
- checkAnswer(
- sql("SELECT - (-10)"), Row(10)
- )
-
- checkAnswer(
- sql("SELECT + (-key) FROM testData WHERE key = 32"), Row(-32)
- )
-
- checkAnswer(
- sql("SELECT - (+Max(key)) FROM testData"), Row(-100)
- )
-
- checkAnswer(
- sql("SELECT - - 3"), Row(3)
- )
-
- checkAnswer(
- sql("SELECT - + 20"), Row(-20)
- )
-
- checkAnswer(
- sql("SELEcT - + 45"), Row(-45)
- )
-
- checkAnswer(
- sql("SELECT + + 100"), Row(100)
- )
-
- checkAnswer(
- sql("SELECT - - Max(key) FROM testData"), Row(100)
- )
-
- checkAnswer(
- sql("SELECT + - key FROM testData WHERE key = 33"), Row(-33)
- )
- }
-
test("Multiple join") {
checkAnswer(
sql(
@@ -1995,15 +1710,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
}
- test("SPARK-11032: resolve having correctly") {
- withTempView("src") {
- Seq(1 -> "a").toDF("i", "j").createOrReplaceTempView("src")
- checkAnswer(
- sql("SELECT MIN(t.i) FROM (SELECT * FROM src WHERE i > 0) t HAVING(COUNT(1) > 0)"),
- Row(1))
- }
- }
-
test("SPARK-11303: filter should not be pushed down into sample") {
val df = spark.range(100)
List(true, false).foreach { withReplacement =>
@@ -2503,70 +2209,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
}
- test("order by ordinal number") {
- checkAnswer(
- sql("SELECT * FROM testData2 ORDER BY 1 DESC"),
- sql("SELECT * FROM testData2 ORDER BY a DESC"))
- // If the position is not an integer, ignore it.
- checkAnswer(
- sql("SELECT * FROM testData2 ORDER BY 1 + 0 DESC, b ASC"),
- sql("SELECT * FROM testData2 ORDER BY b ASC"))
- checkAnswer(
- sql("SELECT * FROM testData2 ORDER BY 1 DESC, b ASC"),
- sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC"))
- checkAnswer(
- sql("SELECT * FROM testData2 SORT BY 1 DESC, 2"),
- sql("SELECT * FROM testData2 SORT BY a DESC, b ASC"))
- checkAnswer(
- sql("SELECT * FROM testData2 ORDER BY 1 ASC, b ASC"),
- Seq(Row(1, 1), Row(1, 2), Row(2, 1), Row(2, 2), Row(3, 1), Row(3, 2)))
- }
-
- test("order by ordinal number - negative cases") {
- intercept[UnresolvedException[SortOrder]] {
- sql("SELECT * FROM testData2 ORDER BY 0")
- }
- intercept[UnresolvedException[SortOrder]] {
- sql("SELECT * FROM testData2 ORDER BY -1 DESC, b ASC")
- }
- intercept[UnresolvedException[SortOrder]] {
- sql("SELECT * FROM testData2 ORDER BY 3 DESC, b ASC")
- }
- }
-
- test("order by ordinal number with conf spark.sql.orderByOrdinal=false") {
- withSQLConf(SQLConf.ORDER_BY_ORDINAL.key -> "false") {
- // If spark.sql.orderByOrdinal=false, ignore the position number.
- checkAnswer(
- sql("SELECT * FROM testData2 ORDER BY 1 DESC, b ASC"),
- sql("SELECT * FROM testData2 ORDER BY b ASC"))
- }
- }
-
- test("natural join") {
- val df1 = Seq(("one", 1), ("two", 2), ("three", 3)).toDF("k", "v1")
- val df2 = Seq(("one", 1), ("two", 22), ("one", 5)).toDF("k", "v2")
- withTempView("nt1", "nt2") {
- df1.createOrReplaceTempView("nt1")
- df2.createOrReplaceTempView("nt2")
- checkAnswer(
- sql("SELECT * FROM nt1 natural join nt2 where k = \"one\""),
- Row("one", 1, 1) :: Row("one", 1, 5) :: Nil)
-
- checkAnswer(
- sql("SELECT * FROM nt1 natural left join nt2 order by v1, v2"),
- Row("one", 1, 1) :: Row("one", 1, 5) :: Row("two", 2, 22) :: Row("three", 3, null) :: Nil)
-
- checkAnswer(
- sql("SELECT * FROM nt1 natural right join nt2 order by v1, v2"),
- Row("one", 1, 1) :: Row("one", 1, 5) :: Row("two", 2, 22) :: Nil)
-
- checkAnswer(
- sql("SELECT count(*) FROM nt1 natural full outer join nt2"),
- Row(4) :: Nil)
- }
- }
-
test("join with using clause") {
val df1 = Seq(("r1c1", "r1c2", "t1r1c3"),
("r2c1", "r2c2", "t1r2c3"), ("r3c1x", "r3c2", "t1r3c3")).toDF("c1", "c2", "c3")
@@ -2950,6 +2592,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
Row(s"$expected") :: Nil)
}
+ test("SPARK-16975: Column-partition path starting '_' should be handled correctly") {
+ withTempDir { dir =>
+ val parquetDir = new File(dir, "parquet").getCanonicalPath
+ spark.range(10).withColumn("_col", $"id").write.partitionBy("_col").save(parquetDir)
+ spark.read.parquet(parquetDir)
+ }
+ }
+
test("SPARK-16644: Aggregate should not put aggregate expressions to constraints") {
withTable("tbl") {
sql("CREATE TABLE tbl(a INT, b INT) USING parquet")
@@ -2981,13 +2631,4 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
data.selectExpr("`part.col1`", "`col.1`"))
}
}
-
- test("current_date and current_timestamp literals") {
- // NOTE that I am comparing the result of the literal with the result of the function call.
- // This is done to prevent the test from failing because we are comparing a result to an out
- // dated timestamp (quite likely) or date (very unlikely - but equally annoying).
- checkAnswer(
- sql("select current_date = current_date(), current_timestamp = current_timestamp()"),
- Seq(Row(true, true)))
- }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
new file mode 100644
index 0000000000000..069a9b665eb36
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
@@ -0,0 +1,274 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.io.File
+import java.util.{Locale, TimeZone}
+
+import scala.util.control.NonFatal
+
+import org.apache.spark.sql.catalyst.planning.PhysicalOperation
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.catalyst.util.{fileToString, stringToFile}
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.StructType
+
+/**
+ * End-to-end test cases for SQL queries.
+ *
+ * Each case is loaded from a file in "spark/sql/core/src/test/resources/sql-tests/inputs".
+ * Each case has a golden result file in "spark/sql/core/src/test/resources/sql-tests/results".
+ *
+ * To re-generate golden files, run:
+ * {{{
+ * SPARK_GENERATE_GOLDEN_FILES=1 build/sbt "sql/test-only *SQLQueryTestSuite"
+ * }}}
+ *
+ * The format for input files is simple:
+ * 1. A list of SQL queries separated by semicolons.
+ * 2. Lines starting with -- are treated as comments and ignored.
+ *
+ * For example:
+ * {{{
+ * -- this is a comment
+ * select 1, -1;
+ * select current_date;
+ * }}}
+ *
+ * The format for golden result files looks roughly like:
+ * {{{
+ * -- some header information
+ *
+ * -- !query 0
+ * select 1, -1
+ * -- !query 0 schema
+ * struct<...schema...>
+ * -- !query 0 output
+ * ... data row 1 ...
+ * ... data row 2 ...
+ * ...
+ *
+ * -- !query 1
+ * ...
+ * }}}
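+ *
+ * Each input file is registered as one ScalaTest test case named after the file, so the usual
+ * ScalaTest name filter should also let you run a single case, e.g. (with a hypothetical input
+ * file named literals.sql):
+ * {{{
+ * build/sbt "sql/test-only *SQLQueryTestSuite -- -z literals.sql"
+ * }}}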
+ */
+class SQLQueryTestSuite extends QueryTest with SharedSQLContext {
+
+ private val regenerateGoldenFiles: Boolean = System.getenv("SPARK_GENERATE_GOLDEN_FILES") == "1"
+
+ private val baseResourcePath = {
+ // If regenerateGoldenFiles is true, we must be running this in SBT and we use a hard-coded
+ // relative path. Otherwise, we use the classloader's getResource to find the location.
+ if (regenerateGoldenFiles) {
+ java.nio.file.Paths.get("src", "test", "resources", "sql-tests").toFile
+ } else {
+ val res = getClass.getClassLoader.getResource("sql-tests")
+ new File(res.getFile)
+ }
+ }
+
+ private val inputFilePath = new File(baseResourcePath, "inputs").getAbsolutePath
+ private val goldenFilePath = new File(baseResourcePath, "results").getAbsolutePath
+
+ /** List of test cases to ignore, in lower case. */
+ private val blackList = Set(
+ "blacklist.sql" // Do NOT remove this one. It is here to test the blacklist functionality.
+ )
+
+ // Create all the test cases.
+ listTestCases().foreach(createScalaTestCase)
+
+ /** A test case. */
+ private case class TestCase(name: String, inputFile: String, resultFile: String)
+
+ /** A single SQL query's output. */
+ private case class QueryOutput(sql: String, schema: String, output: String) {
+ def toString(queryIndex: Int): String = {
+ // We are explicitly not using a multi-line string because stripMargin would remove "|" from the output.
+ s"-- !query $queryIndex\n" +
+ sql + "\n" +
+ s"-- !query $queryIndex schema\n" +
+ schema + "\n" +
+ s"-- !query $queryIndex output\n" +
+ output
+ }
+ }
+
+ private def createScalaTestCase(testCase: TestCase): Unit = {
+ if (blackList.contains(testCase.name.toLowerCase)) {
+ // Create a test case to ignore this case.
+ ignore(testCase.name) { /* Do nothing */ }
+ } else {
+ // Create a test case to run this case.
+ test(testCase.name) { runTest(testCase) }
+ }
+ }
+
+ /** Run a test case. */
+ private def runTest(testCase: TestCase): Unit = {
+ val input = fileToString(new File(testCase.inputFile))
+
+ // List of SQL queries to run
+ val queries: Seq[String] = {
+ val cleaned = input.split("\n").filterNot(_.startsWith("--")).mkString("\n")
+ // Note: this is not a robust way to split queries by semicolon, but it works for now.
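+ // For example, an input of "-- comment\nselect 1;\nselect 2;" becomes Seq("select 1", "select 2").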
+ cleaned.split("(?<=[^\\\\]);").map(_.trim).filter(_ != "").toSeq
+ }
+
+ // Create a local SparkSession to have stronger isolation between different test cases.
+ // This does not isolate catalog changes.
+ val localSparkSession = spark.newSession()
+ loadTestData(localSparkSession)
+
+ // Run the SQL queries, preparing them for comparison.
+ val outputs: Seq[QueryOutput] = queries.map { sql =>
+ val (schema, output) = getNormalizedResult(localSparkSession, sql)
+ // We might need to do some query canonicalization in the future.
+ QueryOutput(
+ sql = sql,
+ schema = schema.catalogString,
+ output = output.mkString("\n").trim)
+ }
+
+ if (regenerateGoldenFiles) {
+ // Again, we are explicitly not using a multi-line string because stripMargin would remove "|".
+ val goldenOutput = {
+ s"-- Automatically generated by ${getClass.getSimpleName}\n" +
+ s"-- Number of queries: ${outputs.size}\n\n\n" +
+ outputs.zipWithIndex.map { case (qr, i) => qr.toString(i) }.mkString("\n\n\n") + "\n"
+ }
+ stringToFile(new File(testCase.resultFile), goldenOutput)
+ }
+
+ // Read back the golden file.
+ val expectedOutputs: Seq[QueryOutput] = {
+ val goldenOutput = fileToString(new File(testCase.resultFile))
+ val segments = goldenOutput.split("-- !query.+\n")
+
+ // each query has 3 segments, plus the header
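+ // segments(0) is the file header; query i occupies segments(i * 3 + 1) through segments(i * 3 + 3)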
+ assert(segments.size == outputs.size * 3 + 1,
+ s"Expected ${outputs.size * 3 + 1} blocks in result file but got ${segments.size}. " +
+ s"Try regenerate the result files.")
+ Seq.tabulate(outputs.size) { i =>
+ QueryOutput(
+ sql = segments(i * 3 + 1).trim,
+ schema = segments(i * 3 + 2).trim,
+ output = segments(i * 3 + 3).trim
+ )
+ }
+ }
+
+ // Compare results.
+ assertResult(expectedOutputs.size, s"Number of queries should be ${expectedOutputs.size}") {
+ outputs.size
+ }
+
+ outputs.zip(expectedOutputs).zipWithIndex.foreach { case ((output, expected), i) =>
+ assertResult(expected.sql, s"SQL query did not match for query #$i\n${expected.sql}") {
+ output.sql
+ }
+ assertResult(expected.schema, s"Schema did not match for query #$i\n${expected.sql}") {
+ output.schema
+ }
+ assertResult(expected.output, s"Result dit not match for query #$i\n${expected.sql}") {
+ output.output
+ }
+ }
+ }
+
+ /** Executes a query and returns the result as (schema of the output, normalized output). */
+ private def getNormalizedResult(session: SparkSession, sql: String): (StructType, Seq[String]) = {
+ // Returns true if the plan is supposed to be sorted.
+ def isSorted(plan: LogicalPlan): Boolean = plan match {
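+ // The output of these operators has no guaranteed ordering, so their results are sorted
+ // below before comparison.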
+ case _: Join | _: Aggregate | _: Generate | _: Sample | _: Distinct => false
+ case PhysicalOperation(_, _, Sort(_, true, _)) => true
+ case _ => plan.children.iterator.exists(isSorted)
+ }
+
+ try {
+ val df = session.sql(sql)
+ val schema = df.schema
+ val answer = df.queryExecution.hiveResultString()
+
+ // If the output is not pre-sorted, sort it.
+ if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted)
+
+ } catch {
+ case NonFatal(e) =>
+ // If there is an exception, put the exception class followed by the message.
+ (StructType(Seq.empty), Seq(e.getClass.getName, e.getMessage))
+ }
+ }
+
+ private def listTestCases(): Seq[TestCase] = {
+ listFilesRecursively(new File(inputFilePath)).map { file =>
+ val resultFile = file.getAbsolutePath.replace(inputFilePath, goldenFilePath) + ".out"
+ TestCase(file.getName, file.getAbsolutePath, resultFile)
+ }
+ }
+
+ /** Returns all the files (not directories) in a directory, recursively. */
+ private def listFilesRecursively(path: File): Seq[File] = {
+ val (dirs, files) = path.listFiles().partition(_.isDirectory)
+ files ++ dirs.flatMap(listFilesRecursively)
+ }
+
+ /** Load built-in test tables into the SparkSession. */
+ private def loadTestData(session: SparkSession): Unit = {
+ import session.implicits._
+
+ (1 to 100).map(i => (i, i.toString)).toDF("key", "value").createOrReplaceTempView("testdata")
+
+ ((Seq(1, 2, 3), Seq(Seq(1, 2, 3))) :: (Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil)
+ .toDF("arraycol", "nestedarraycol")
+ .createOrReplaceTempView("arraydata")
+
+ (Tuple1(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) ::
+ Tuple1(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) ::
+ Tuple1(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) ::
+ Tuple1(Map(1 -> "a4", 2 -> "b4")) ::
+ Tuple1(Map(1 -> "a5")) :: Nil)
+ .toDF("mapcol")
+ .createOrReplaceTempView("mapdata")
+ }
+
+ private val originalTimeZone = TimeZone.getDefault
+ private val originalLocale = Locale.getDefault
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ // Timezone is fixed to America/Los_Angeles for timezone-sensitive tests (timestamp_*)
+ TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles"))
+ // Add Locale setting
+ Locale.setDefault(Locale.US)
+ RuleExecutor.resetTime()
+ }
+
+ override def afterAll(): Unit = {
+ try {
+ TimeZone.setDefault(originalTimeZone)
+ Locale.setDefault(originalLocale)
+
+ // For debugging, dump some statistics about how much time was spent in various optimizer rules
+ logWarning(RuleExecutor.dumpTimeSpent())
+ } finally {
+ super.afterAll()
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala
index 418345b9ee8f2..386d13d07a95f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala
@@ -100,6 +100,7 @@ class SparkSessionBuilderSuite extends SparkFunSuite {
assert(session.conf.get("key2") == "value2")
assert(session.sparkContext.conf.get("key1") == "value1")
assert(session.sparkContext.conf.get("key2") == "value2")
+ assert(session.sparkContext.conf.get("spark.app.name") == "test")
session.stop()
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index 6edd7b0c25b97..9be2de9c7d719 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -94,6 +94,10 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext {
test("non-matching optional group") {
val df = Seq(Tuple1("aaaac")).toDF("s")
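+ // A pattern that does not match the input at all should yield an empty string.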
+ checkAnswer(
+ df.select(regexp_extract($"s", "(foo)", 1)),
+ Row("")
+ )
checkAnswer(
df.select(regexp_extract($"s", "(a+)(b)?(c)", 2)),
Row("")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala
index 7b96f4c99ab5a..8d74884df9273 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala
@@ -564,6 +564,14 @@ class DDLCommandSuite extends PlanTest {
comparePlans(parsed2, expected2)
}
+ test("alter table: recover partitions") {
+ val sql = "ALTER TABLE table_name RECOVER PARTITIONS"
+ val parsed = parser.parsePlan(sql)
+ val expected = AlterTableRecoverPartitionsCommand(
+ TableIdentifier("table_name", None))
+ comparePlans(parsed, expected)
+ }
+
test("alter view: add partition (not supported)") {
assertUnsupported(
"""
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
index f2ec393c30eca..d70cae74bc6c9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
@@ -111,10 +111,6 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
catalog.createPartitions(tableName, Seq(part), ignoreIfExists = false)
}
- private def appendTrailingSlash(path: String): String = {
- if (!path.endsWith(File.separator)) path + File.separator else path
- }
-
test("the qualified path of a database is stored in the catalog") {
val catalog = spark.sessionState.catalog
@@ -122,18 +118,19 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
val path = tmpDir.toString
// The generated temp path is not qualified.
assert(!path.startsWith("file:/"))
- sql(s"CREATE DATABASE db1 LOCATION '$path'")
+ val uri = tmpDir.toURI
+ sql(s"CREATE DATABASE db1 LOCATION '$uri'")
val pathInCatalog = new Path(catalog.getDatabaseMetadata("db1").locationUri).toUri
assert("file" === pathInCatalog.getScheme)
- val expectedPath = if (path.endsWith(File.separator)) path.dropRight(1) else path
- assert(expectedPath === pathInCatalog.getPath)
+ val expectedPath = new Path(path).toUri
+ assert(expectedPath.getPath === pathInCatalog.getPath)
withSQLConf(SQLConf.WAREHOUSE_PATH.key -> path) {
sql(s"CREATE DATABASE db2")
- val pathInCatalog = new Path(catalog.getDatabaseMetadata("db2").locationUri).toUri
- assert("file" === pathInCatalog.getScheme)
- val expectedPath = appendTrailingSlash(spark.sessionState.conf.warehousePath) + "db2.db"
- assert(expectedPath === pathInCatalog.getPath)
+ val pathInCatalog2 = new Path(catalog.getDatabaseMetadata("db2").locationUri).toUri
+ assert("file" === pathInCatalog2.getScheme)
+ val expectedPath2 = new Path(spark.sessionState.conf.warehousePath + "/" + "db2.db").toUri
+ assert(expectedPath2.getPath === pathInCatalog2.getPath)
}
sql("DROP DATABASE db1")
@@ -141,6 +138,13 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
}
}
+ private def makeQualifiedPath(path: String): String = {
+ // Copied from SessionCatalog.
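+ // For example, on a local file system a path like "/tmp/db1.db" would typically become
+ // "file:/tmp/db1.db".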
+ val hadoopPath = new Path(path)
+ val fs = hadoopPath.getFileSystem(sparkContext.hadoopConfiguration)
+ fs.makeQualified(hadoopPath).toString
+ }
+
test("Create/Drop Database") {
withTempDir { tmpDir =>
val path = tmpDir.toString
@@ -154,8 +158,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
sql(s"CREATE DATABASE $dbName")
val db1 = catalog.getDatabaseMetadata(dbNameWithoutBackTicks)
- val expectedLocation =
- "file:" + appendTrailingSlash(path) + s"$dbNameWithoutBackTicks.db"
+ val expectedLocation = makeQualifiedPath(path + "/" + s"$dbNameWithoutBackTicks.db")
assert(db1 == CatalogDatabase(
dbNameWithoutBackTicks,
"",
@@ -181,8 +184,8 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
sql(s"CREATE DATABASE $dbName")
val db1 = catalog.getDatabaseMetadata(dbName)
val expectedLocation =
- "file:" + appendTrailingSlash(System.getProperty("user.dir")) +
- s"spark-warehouse/$dbName.db"
+ makeQualifiedPath(s"${System.getProperty("user.dir")}/spark-warehouse" +
+ "/" + s"$dbName.db")
assert(db1 == CatalogDatabase(
dbName,
"",
@@ -200,17 +203,17 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
val catalog = spark.sessionState.catalog
val databaseNames = Seq("db1", "`database`")
withTempDir { tmpDir =>
- val path = tmpDir.toString
- val dbPath = "file:" + path
+ val path = new Path(tmpDir.toString).toUri.toString
databaseNames.foreach { dbName =>
try {
val dbNameWithoutBackTicks = cleanIdentifier(dbName)
sql(s"CREATE DATABASE $dbName Location '$path'")
val db1 = catalog.getDatabaseMetadata(dbNameWithoutBackTicks)
+ val expPath = makeQualifiedPath(tmpDir.toString)
assert(db1 == CatalogDatabase(
dbNameWithoutBackTicks,
"",
- if (dbPath.endsWith(File.separator)) dbPath.dropRight(1) else dbPath,
+ expPath,
Map.empty))
sql(s"DROP DATABASE $dbName CASCADE")
assert(!catalog.databaseExists(dbNameWithoutBackTicks))
@@ -233,8 +236,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
val dbNameWithoutBackTicks = cleanIdentifier(dbName)
sql(s"CREATE DATABASE $dbName")
val db1 = catalog.getDatabaseMetadata(dbNameWithoutBackTicks)
- val expectedLocation =
- "file:" + appendTrailingSlash(path) + s"$dbNameWithoutBackTicks.db"
+ val expectedLocation = makeQualifiedPath(path + "/" + s"$dbNameWithoutBackTicks.db")
assert(db1 == CatalogDatabase(
dbNameWithoutBackTicks,
"",
@@ -275,7 +277,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
databaseNames.foreach { dbName =>
try {
val dbNameWithoutBackTicks = cleanIdentifier(dbName)
- val location = "file:" + appendTrailingSlash(path) + s"$dbNameWithoutBackTicks.db"
+ val location = makeQualifiedPath(path + "/" + s"$dbNameWithoutBackTicks.db")
sql(s"CREATE DATABASE $dbName")
@@ -436,7 +438,8 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
}
test("create temporary view using") {
- val csvFile = Thread.currentThread().getContextClassLoader.getResource("cars.csv").toString()
+ val csvFile =
+ Thread.currentThread().getContextClassLoader.getResource("test-data/cars.csv").toString
withView("testview") {
sql(s"CREATE OR REPLACE TEMPORARY VIEW testview (c1: String, c2: String) USING " +
"org.apache.spark.sql.execution.datasources.csv.CSVFileFormat " +
@@ -628,6 +631,55 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
testAddPartitions(isDatasourceTable = true)
}
+ test("alter table: recover partitions (sequential)") {
+ withSQLConf("spark.rdd.parallelListingThreshold" -> "1") {
+ testRecoverPartitions()
+ }
+ }
+
+ test("alter table: recover partition (parallel)") {
+ withSQLConf("spark.rdd.parallelListingThreshold" -> "10") {
+ testRecoverPartitions()
+ }
+ }
+
+ private def testRecoverPartitions(): Unit = {
+ val catalog = spark.sessionState.catalog
+ // table to alter does not exist
+ intercept[AnalysisException] {
+ sql("ALTER TABLE does_not_exist RECOVER PARTITIONS")
+ }
+
+ val tableIdent = TableIdentifier("tab1")
+ createTable(catalog, tableIdent)
+ val part1 = Map("a" -> "1", "b" -> "5")
+ createTablePartition(catalog, part1, tableIdent)
+ assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1))
+
+ val part2 = Map("a" -> "2", "b" -> "6")
+ val root = new Path(catalog.getTableMetadata(tableIdent).storage.locationUri.get)
+ val fs = root.getFileSystem(spark.sparkContext.hadoopConfiguration)
+ // valid
+ fs.mkdirs(new Path(new Path(root, "a=1"), "b=5"))
+ fs.mkdirs(new Path(new Path(root, "A=2"), "B=6"))
+ // invalid
+ fs.mkdirs(new Path(new Path(root, "a"), "b")) // bad name
+ fs.mkdirs(new Path(new Path(root, "b=1"), "a=1")) // wrong order
+ fs.mkdirs(new Path(root, "a=4")) // not enough columns
+ fs.createNewFile(new Path(new Path(root, "a=1"), "b=4")) // file
+ fs.createNewFile(new Path(new Path(root, "a=1"), "_SUCCESS")) // _SUCCESS
+ fs.mkdirs(new Path(new Path(root, "a=1"), "_temporary")) // _temporary
+ fs.mkdirs(new Path(new Path(root, "a=1"), ".b=4")) // start with .
+
+ try {
+ sql("ALTER TABLE tab1 RECOVER PARTITIONS")
+ assert(catalog.listPartitions(tableIdent).map(_.spec).toSet ==
+ Set(part1, part2))
+ } finally {
+ fs.delete(root, true)
+ }
+ }
+
test("alter table: add partition is not supported for views") {
assertUnsupported("ALTER VIEW dbx.tab1 ADD IF NOT EXISTS PARTITION (b='2')")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 311f1fa8d2aff..8cd76ddf20f04 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -33,23 +33,23 @@ import org.apache.spark.sql.types._
class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
import testImplicits._
- private val carsFile = "cars.csv"
- private val carsMalformedFile = "cars-malformed.csv"
- private val carsFile8859 = "cars_iso-8859-1.csv"
- private val carsTsvFile = "cars.tsv"
- private val carsAltFile = "cars-alternative.csv"
- private val carsUnbalancedQuotesFile = "cars-unbalanced-quotes.csv"
- private val carsNullFile = "cars-null.csv"
- private val carsBlankColName = "cars-blank-column-name.csv"
- private val emptyFile = "empty.csv"
- private val commentsFile = "comments.csv"
- private val disableCommentsFile = "disable_comments.csv"
- private val boolFile = "bool.csv"
- private val decimalFile = "decimal.csv"
- private val simpleSparseFile = "simple_sparse.csv"
- private val numbersFile = "numbers.csv"
- private val datesFile = "dates.csv"
- private val unescapedQuotesFile = "unescaped-quotes.csv"
+ private val carsFile = "test-data/cars.csv"
+ private val carsMalformedFile = "test-data/cars-malformed.csv"
+ private val carsFile8859 = "test-data/cars_iso-8859-1.csv"
+ private val carsTsvFile = "test-data/cars.tsv"
+ private val carsAltFile = "test-data/cars-alternative.csv"
+ private val carsUnbalancedQuotesFile = "test-data/cars-unbalanced-quotes.csv"
+ private val carsNullFile = "test-data/cars-null.csv"
+ private val carsBlankColName = "test-data/cars-blank-column-name.csv"
+ private val emptyFile = "test-data/empty.csv"
+ private val commentsFile = "test-data/comments.csv"
+ private val disableCommentsFile = "test-data/disable_comments.csv"
+ private val boolFile = "test-data/bool.csv"
+ private val decimalFile = "test-data/decimal.csv"
+ private val simpleSparseFile = "test-data/simple_sparse.csv"
+ private val numbersFile = "test-data/numbers.csv"
+ private val datesFile = "test-data/dates.csv"
+ private val unescapedQuotesFile = "test-data/unescaped-quotes.csv"
private def testFile(fileName: String): String = {
Thread.currentThread().getContextClassLoader.getResource(fileName).toString
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
index 2a89773cf5341..ab9250045f5b1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
+import org.apache.spark.util.{AccumulatorContext, LongAccumulator}
/**
* A test suite that tests Parquet filter2 API based filter pushdown optimization.
@@ -370,73 +371,75 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
test("SPARK-11103: Filter applied on merged Parquet schema with new column fails") {
import testImplicits._
-
- withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true",
- SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true") {
- withTempPath { dir =>
- val pathOne = s"${dir.getCanonicalPath}/table1"
- (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(pathOne)
- val pathTwo = s"${dir.getCanonicalPath}/table2"
- (1 to 3).map(i => (i, i.toString)).toDF("c", "b").write.parquet(pathTwo)
-
- // If the "c = 1" filter gets pushed down, this query will throw an exception which
- // Parquet emits. This is a Parquet issue (PARQUET-389).
- val df = spark.read.parquet(pathOne, pathTwo).filter("c = 1").selectExpr("c", "b", "a")
- checkAnswer(
- df,
- Row(1, "1", null))
-
- // The fields "a" and "c" only exist in one Parquet file.
- assert(df.schema("a").metadata.getBoolean(StructType.metadataKeyForOptionalField))
- assert(df.schema("c").metadata.getBoolean(StructType.metadataKeyForOptionalField))
-
- val pathThree = s"${dir.getCanonicalPath}/table3"
- df.write.parquet(pathThree)
-
- // We will remove the temporary metadata when writing Parquet file.
- val schema = spark.read.parquet(pathThree).schema
- assert(schema.forall(!_.metadata.contains(StructType.metadataKeyForOptionalField)))
-
- val pathFour = s"${dir.getCanonicalPath}/table4"
- val dfStruct = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b")
- dfStruct.select(struct("a").as("s")).write.parquet(pathFour)
-
- val pathFive = s"${dir.getCanonicalPath}/table5"
- val dfStruct2 = sparkContext.parallelize(Seq((1, 1))).toDF("c", "b")
- dfStruct2.select(struct("c").as("s")).write.parquet(pathFive)
-
- // If the "s.c = 1" filter gets pushed down, this query will throw an exception which
- // Parquet emits.
- val dfStruct3 = spark.read.parquet(pathFour, pathFive).filter("s.c = 1")
- .selectExpr("s")
- checkAnswer(dfStruct3, Row(Row(null, 1)))
-
- // The fields "s.a" and "s.c" only exist in one Parquet file.
- val field = dfStruct3.schema("s").dataType.asInstanceOf[StructType]
- assert(field("a").metadata.getBoolean(StructType.metadataKeyForOptionalField))
- assert(field("c").metadata.getBoolean(StructType.metadataKeyForOptionalField))
-
- val pathSix = s"${dir.getCanonicalPath}/table6"
- dfStruct3.write.parquet(pathSix)
-
- // We will remove the temporary metadata when writing Parquet file.
- val forPathSix = spark.read.parquet(pathSix).schema
- assert(forPathSix.forall(!_.metadata.contains(StructType.metadataKeyForOptionalField)))
-
- // sanity test: make sure optional metadata field is not wrongly set.
- val pathSeven = s"${dir.getCanonicalPath}/table7"
- (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(pathSeven)
- val pathEight = s"${dir.getCanonicalPath}/table8"
- (4 to 6).map(i => (i, i.toString)).toDF("a", "b").write.parquet(pathEight)
-
- val df2 = spark.read.parquet(pathSeven, pathEight).filter("a = 1").selectExpr("a", "b")
- checkAnswer(
- df2,
- Row(1, "1"))
-
- // The fields "a" and "b" exist in both two Parquet files. No metadata is set.
- assert(!df2.schema("a").metadata.contains(StructType.metadataKeyForOptionalField))
- assert(!df2.schema("b").metadata.contains(StructType.metadataKeyForOptionalField))
+ Seq("true", "false").map { vectorized =>
+ withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true",
+ SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true",
+ SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) {
+ withTempPath { dir =>
+ val pathOne = s"${dir.getCanonicalPath}/table1"
+ (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(pathOne)
+ val pathTwo = s"${dir.getCanonicalPath}/table2"
+ (1 to 3).map(i => (i, i.toString)).toDF("c", "b").write.parquet(pathTwo)
+
+ // If the "c = 1" filter gets pushed down, this query will throw an exception which
+ // Parquet emits. This is a Parquet issue (PARQUET-389).
+ val df = spark.read.parquet(pathOne, pathTwo).filter("c = 1").selectExpr("c", "b", "a")
+ checkAnswer(
+ df,
+ Row(1, "1", null))
+
+ // The fields "a" and "c" only exist in one Parquet file.
+ assert(df.schema("a").metadata.getBoolean(StructType.metadataKeyForOptionalField))
+ assert(df.schema("c").metadata.getBoolean(StructType.metadataKeyForOptionalField))
+
+ val pathThree = s"${dir.getCanonicalPath}/table3"
+ df.write.parquet(pathThree)
+
+ // We will remove the temporary metadata when writing Parquet file.
+ val schema = spark.read.parquet(pathThree).schema
+ assert(schema.forall(!_.metadata.contains(StructType.metadataKeyForOptionalField)))
+
+ val pathFour = s"${dir.getCanonicalPath}/table4"
+ val dfStruct = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b")
+ dfStruct.select(struct("a").as("s")).write.parquet(pathFour)
+
+ val pathFive = s"${dir.getCanonicalPath}/table5"
+ val dfStruct2 = sparkContext.parallelize(Seq((1, 1))).toDF("c", "b")
+ dfStruct2.select(struct("c").as("s")).write.parquet(pathFive)
+
+ // If the "s.c = 1" filter gets pushed down, this query will throw an exception which
+ // Parquet emits.
+ val dfStruct3 = spark.read.parquet(pathFour, pathFive).filter("s.c = 1")
+ .selectExpr("s")
+ checkAnswer(dfStruct3, Row(Row(null, 1)))
+
+ // The fields "s.a" and "s.c" only exist in one Parquet file.
+ val field = dfStruct3.schema("s").dataType.asInstanceOf[StructType]
+ assert(field("a").metadata.getBoolean(StructType.metadataKeyForOptionalField))
+ assert(field("c").metadata.getBoolean(StructType.metadataKeyForOptionalField))
+
+ val pathSix = s"${dir.getCanonicalPath}/table6"
+ dfStruct3.write.parquet(pathSix)
+
+ // We will remove the temporary metadata when writing Parquet file.
+ val forPathSix = spark.read.parquet(pathSix).schema
+ assert(forPathSix.forall(!_.metadata.contains(StructType.metadataKeyForOptionalField)))
+
+ // sanity test: make sure optional metadata field is not wrongly set.
+ val pathSeven = s"${dir.getCanonicalPath}/table7"
+ (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(pathSeven)
+ val pathEight = s"${dir.getCanonicalPath}/table8"
+ (4 to 6).map(i => (i, i.toString)).toDF("a", "b").write.parquet(pathEight)
+
+ val df2 = spark.read.parquet(pathSeven, pathEight).filter("a = 1").selectExpr("a", "b")
+ checkAnswer(
+ df2,
+ Row(1, "1"))
+
+ // The fields "a" and "b" exist in both two Parquet files. No metadata is set.
+ assert(!df2.schema("a").metadata.contains(StructType.metadataKeyForOptionalField))
+ assert(!df2.schema("b").metadata.contains(StructType.metadataKeyForOptionalField))
+ }
}
}
}
@@ -559,4 +562,32 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
assert(df.filter("_1 IS NOT NULL").count() === 4)
}
}
+
+ test("Fiters should be pushed down for vectorized Parquet reader at row group level") {
+ import testImplicits._
+
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true",
+ SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
+ withTempPath { dir =>
+ val path = s"${dir.getCanonicalPath}/table"
+ (1 to 1024).map(i => (101, i)).toDF("a", "b").write.parquet(path)
+
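+ // With filter pushdown enabled, the "a < 100" predicate should let every row group be
+ // skipped (all values of "a" are 101), so "numRowGroups" stays 0; with pushdown disabled
+ // at least one row group is read.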
+ Seq(("true", (x: Long) => x == 0), ("false", (x: Long) => x > 0)).map { case (push, func) =>
+ withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> push) {
+ val accu = new LongAccumulator
+ accu.register(sparkContext, Some("numRowGroups"))
+
+ val df = spark.read.parquet(path).filter("a < 100")
+ df.foreachPartition(_.foreach(v => accu.add(0)))
+ df.collect
+
+ val numRowGroups = AccumulatorContext.lookForAccumulatorByName("numRowGroups")
+ assert(numRowGroups.isDefined)
+ assert(func(numRowGroups.get.asInstanceOf[LongAccumulator].value))
+ AccumulatorContext.remove(accu.id)
+ }
+ }
+ }
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
index fc9ce6bb3041b..a95de2ea9135f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
@@ -556,7 +556,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) {
checkAnswer(
// Decimal column in this file is encoded using plain dictionary
- readResourceParquetFile("dec-in-i32.parquet"),
+ readResourceParquetFile("test-data/dec-in-i32.parquet"),
spark.range(1 << 4).select('id % 10 cast DecimalType(5, 2) as 'i32_dec))
}
}
@@ -567,7 +567,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) {
checkAnswer(
// Decimal column in this file is encoded using plain dictionary
- readResourceParquetFile("dec-in-i64.parquet"),
+ readResourceParquetFile("test-data/dec-in-i64.parquet"),
spark.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'i64_dec))
}
}
@@ -578,7 +578,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) {
checkAnswer(
// Decimal column in this file is encoded using plain dictionary
- readResourceParquetFile("dec-in-fixed-len.parquet"),
+ readResourceParquetFile("test-data/dec-in-fixed-len.parquet"),
spark.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'fixed_len_dec))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala
index 98333e58cada8..fa88019298a69 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala
@@ -22,12 +22,12 @@ import org.apache.spark.sql.test.SharedSQLContext
class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext {
test("unannotated array of primitive type") {
- checkAnswer(readResourceParquetFile("old-repeated-int.parquet"), Row(Seq(1, 2, 3)))
+ checkAnswer(readResourceParquetFile("test-data/old-repeated-int.parquet"), Row(Seq(1, 2, 3)))
}
test("unannotated array of struct") {
checkAnswer(
- readResourceParquetFile("old-repeated-message.parquet"),
+ readResourceParquetFile("test-data/old-repeated-message.parquet"),
Row(
Seq(
Row("First inner", null, null),
@@ -35,14 +35,14 @@ class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with Sh
Row(null, null, "Third inner"))))
checkAnswer(
- readResourceParquetFile("proto-repeated-struct.parquet"),
+ readResourceParquetFile("test-data/proto-repeated-struct.parquet"),
Row(
Seq(
Row("0 - 1", "0 - 2", "0 - 3"),
Row("1 - 1", "1 - 2", "1 - 3"))))
checkAnswer(
- readResourceParquetFile("proto-struct-with-array-many.parquet"),
+ readResourceParquetFile("test-data/proto-struct-with-array-many.parquet"),
Seq(
Row(
Seq(
@@ -60,13 +60,13 @@ class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with Sh
test("struct with unannotated array") {
checkAnswer(
- readResourceParquetFile("proto-struct-with-array.parquet"),
+ readResourceParquetFile("test-data/proto-struct-with-array.parquet"),
Row(10, 9, Seq.empty, null, Row(9), Seq(Row(9), Row(10))))
}
test("unannotated array of struct with unannotated array") {
checkAnswer(
- readResourceParquetFile("nested-array-struct.parquet"),
+ readResourceParquetFile("test-data/nested-array-struct.parquet"),
Seq(
Row(2, Seq(Row(1, Seq(Row(3))))),
Row(5, Seq(Row(4, Seq(Row(6))))),
@@ -75,7 +75,7 @@ class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with Sh
test("unannotated array of string") {
checkAnswer(
- readResourceParquetFile("proto-repeated-string.parquet"),
+ readResourceParquetFile("test-data/proto-repeated-string.parquet"),
Seq(
Row(Seq("hello", "world")),
Row(Seq("good", "bye")),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala
index ff5706999a6dd..4157a5b46dc42 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala
@@ -23,8 +23,8 @@ import org.apache.spark.sql.test.SharedSQLContext
class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext {
import ParquetCompatibilityTest._
- private val parquetFilePath =
- Thread.currentThread().getContextClassLoader.getResource("parquet-thrift-compat.snappy.parquet")
+ private val parquetFilePath = Thread.currentThread().getContextClassLoader.getResource(
+ "test-data/parquet-thrift-compat.snappy.parquet")
test("Read Parquet file generated by parquet-thrift") {
logInfo(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala
index 71d3da915840a..d11c2acb815d4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala
@@ -66,7 +66,7 @@ class TextSuite extends QueryTest with SharedSQLContext {
test("reading partitioned data using read.textFile()") {
val partitionedData = Thread.currentThread().getContextClassLoader
- .getResource("text-partitioned").toString
+ .getResource("test-data/text-partitioned").toString
val ds = spark.read.textFile(partitionedData)
val data = ds.collect()
@@ -76,7 +76,7 @@ class TextSuite extends QueryTest with SharedSQLContext {
test("support for partitioned reading using read.text()") {
val partitionedData = Thread.currentThread().getContextClassLoader
- .getResource("text-partitioned").toString
+ .getResource("test-data/text-partitioned").toString
val df = spark.read.text(partitionedData)
val data = df.filter("year = '2015'").select("value").collect()
@@ -155,7 +155,7 @@ class TextSuite extends QueryTest with SharedSQLContext {
}
private def testFile: String = {
- Thread.currentThread().getContextClassLoader.getResource("text-suite.txt").toString
+ Thread.currentThread().getContextClassLoader.getResource("test-data/text-suite.txt").toString
}
/** Verifies data and schema. */
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala
new file mode 100644
index 0000000000000..d826d3f54d922
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.Encoders
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+
+class ReduceAggregatorSuite extends SparkFunSuite {
+
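+ // These tests assume ReduceAggregator buffers its state as a (Boolean, T) pair, where the
+ // flag records whether any input row has been reduced yet; finish() on the zero value fails.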
+ test("zero value") {
+ val encoder: ExpressionEncoder[Int] = ExpressionEncoder()
+ val func = (v1: Int, v2: Int) => v1 + v2
+ val aggregator: ReduceAggregator[Int] = new ReduceAggregator(func)(Encoders.scalaInt)
+ assert(aggregator.zero == (false, null))
+ }
+
+ test("reduce, merge and finish") {
+ val encoder: ExpressionEncoder[Int] = ExpressionEncoder()
+ val func = (v1: Int, v2: Int) => v1 + v2
+ val aggregator: ReduceAggregator[Int] = new ReduceAggregator(func)(Encoders.scalaInt)
+
+ val firstReduce = aggregator.reduce(aggregator.zero, 1)
+ assert(firstReduce == (true, 1))
+
+ val secondReduce = aggregator.reduce(firstReduce, 2)
+ assert(secondReduce == (true, 3))
+
+ val thirdReduce = aggregator.reduce(secondReduce, 3)
+ assert(thirdReduce == (true, 6))
+
+ val mergeWithZero1 = aggregator.merge(aggregator.zero, firstReduce)
+ assert(mergeWithZero1 == (true, 1))
+
+ val mergeWithZero2 = aggregator.merge(secondReduce, aggregator.zero)
+ assert(mergeWithZero2 == (true, 3))
+
+ val mergeTwoReduced = aggregator.merge(firstReduce, secondReduce)
+ assert(mergeTwoReduced == (true, 4))
+
+ assert(aggregator.finish(firstReduce) == 1)
+ assert(aggregator.finish(secondReduce) == 3)
+ assert(aggregator.finish(thirdReduce) == 6)
+ assert(aggregator.finish(mergeWithZero1) == 1)
+ assert(aggregator.finish(mergeWithZero2) == 3)
+ assert(aggregator.finish(mergeTwoReduced) == 4)
+ }
+
+ test("requires at least one input row") {
+ val encoder: ExpressionEncoder[Int] = ExpressionEncoder()
+ val func = (v1: Int, v2: Int) => v1 + v2
+ val aggregator: ReduceAggregator[Int] = new ReduceAggregator(func)(Encoders.scalaInt)
+
+ intercept[IllegalStateException] {
+ aggregator.finish(aggregator.zero)
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
index 5d348044515af..761bbe3576c71 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.internal
+import org.apache.hadoop.fs.Path
+
import org.apache.spark.sql.{QueryTest, Row, SparkSession, SQLContext}
import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext}
@@ -214,7 +216,7 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
// to get the default value, always unset it
spark.conf.unset(SQLConf.WAREHOUSE_PATH.key)
assert(spark.sessionState.conf.warehousePath
- === s"file:${System.getProperty("user.dir")}/spark-warehouse")
+ === new Path(s"${System.getProperty("user.dir")}/spark-warehouse").toString)
} finally {
sql(s"set ${SQLConf.WAREHOUSE_PATH}=$original")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala
index 7f4d28cf0598f..77602e8167fa3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala
@@ -94,7 +94,6 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
assert(status.id === query.id)
assert(status.sourceStatuses(0).offsetDesc === Some(LongOffset(0).toString))
assert(status.sinkStatus.offsetDesc === CompositeOffset.fill(LongOffset(0)).toString)
- assert(listener.terminationStackTrace.isEmpty)
assert(listener.terminationException === None)
}
listener.checkAsyncErrors()
@@ -147,7 +146,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
}
}
- test("exception should be reported in QueryTerminated") {
+ testQuietly("exception should be reported in QueryTerminated") {
val listener = new QueryStatusCollector
withListenerAdded(listener) {
val input = MemoryStream[Int]
@@ -159,8 +158,11 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
spark.sparkContext.listenerBus.waitUntilEmpty(10000)
assert(listener.terminationStatus !== null)
assert(listener.terminationException.isDefined)
+ // Make sure that the exception message reported through the listener
+ // contains the actual exception and relevant stack trace
+ assert(!listener.terminationException.get.contains("StreamingQueryException"))
assert(listener.terminationException.get.contains("java.lang.ArithmeticException"))
- assert(listener.terminationStackTrace.nonEmpty)
+ assert(listener.terminationException.get.contains("StreamingQueryListenerSuite"))
}
)
}
@@ -205,8 +207,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
val exception = new RuntimeException("exception")
val queryQueryTerminated = new StreamingQueryListener.QueryTerminated(
queryTerminatedInfo,
- Some(exception.getMessage),
- exception.getStackTrace)
+ Some(exception.getMessage))
val json =
JsonProtocol.sparkEventToJson(queryQueryTerminated)
val newQueryTerminated = JsonProtocol.sparkEventFromJson(json)
@@ -262,7 +263,6 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
@volatile var startStatus: StreamingQueryInfo = null
@volatile var terminationStatus: StreamingQueryInfo = null
@volatile var terminationException: Option[String] = null
- @volatile var terminationStackTrace: Seq[StackTraceElement] = null
val progressStatuses = new ConcurrentLinkedQueue[StreamingQueryInfo]
@@ -296,7 +296,6 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
assert(startStatus != null, "onQueryTerminated called before onQueryStarted")
terminationStatus = queryTerminated.queryInfo
terminationException = queryTerminated.exception
- terminationStackTrace = queryTerminated.stackTrace
}
asyncTestWaiter.dismiss()
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
index 15a5d79dcb085..3a8b0f1b8ebdf 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
@@ -34,7 +34,6 @@ import org.apache.spark.sql.hive.MetastoreRelation
* @param ignoreIfExists allow continuing if the table already exists; otherwise
* raise an exception
*/
-private[hive]
case class CreateHiveTableAsSelectCommand(
tableDesc: CatalogTable,
query: LogicalPlan,
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
index cc3e74b4e8ccc..a716a3eab6219 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
@@ -54,7 +54,7 @@ case class HiveTableScanExec(
require(partitionPruningPred.isEmpty || relation.hiveQlTable.isPartitioned,
"Partition pruning predicates only supported for partitioned tables.")
- private[sql] override lazy val metrics = Map(
+ override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
override def producedAttributes: AttributeSet = outputSet ++
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
index dfb12512a40fc..9747abbf15a55 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
@@ -51,7 +51,6 @@ import org.apache.spark.util.{CircularBuffer, RedirectThread, SerializableConfig
* @param script the command that should be executed.
* @param output the attributes that are produced by the script.
*/
-private[hive]
case class ScriptTransformation(
input: Seq[Expression],
script: String,
@@ -336,7 +335,6 @@ private class ScriptTransformationWriterThread(
}
}
-private[hive]
object HiveScriptIOSchema {
def apply(input: ScriptInputOutputSchema): HiveScriptIOSchema = {
HiveScriptIOSchema(
@@ -355,7 +353,6 @@ object HiveScriptIOSchema {
/**
* The wrapper class of Hive input and output schema properties
*/
-private[hive]
case class HiveScriptIOSchema (
inputRowFormat: Seq[(String, String)],
outputRowFormat: Seq[(String, String)],
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
index a2c8092e01bb9..9843f0774af96 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
@@ -47,7 +47,7 @@ import org.apache.spark.util.SerializableConfiguration
* [[FileFormat]] for reading ORC files. If this is moved or renamed, please update
* [[DataSource]]'s backwardCompatibilityMap.
*/
-private[sql] class OrcFileFormat
+class OrcFileFormat
extends FileFormat with DataSourceRegister with Serializable {
override def shortName(): String = "orc"
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala
index fef726c5d801d..7249df813b17f 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala
@@ -75,8 +75,8 @@ class ExpressionSQLBuilderSuite extends SQLBuilderTest {
checkSQL('a.int / 'b.int, "(`a` / `b`)")
checkSQL('a.int % 'b.int, "(`a` % `b`)")
- checkSQL(-'a.int, "(-`a`)")
- checkSQL(-('a.int + 'b.int), "(-(`a` + `b`))")
+ checkSQL(-'a.int, "(- `a`)")
+ checkSQL(-('a.int + 'b.int), "(- (`a` + `b`))")
}
test("window specification") {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala
index d8ab864ca6fce..4e5a51155defd 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala
@@ -41,8 +41,7 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
import testImplicits._
// Used for generating new query answer files by saving
- private val regenerateGoldenFiles: Boolean =
- Option(System.getenv("SPARK_GENERATE_GOLDEN_FILES")) == Some("1")
+ private val regenerateGoldenFiles: Boolean = System.getenv("SPARK_GENERATE_GOLDEN_FILES") == "1"
private val goldenSQLPath = "src/test/resources/sqlgen/"
protected override def beforeAll(): Unit = {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala
index 867aadb5f5569..54009d4b4130a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.hive
import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogTable, CatalogTableType}
import org.apache.spark.sql.catalyst.dsl.expressions._
@@ -520,8 +521,13 @@ class HiveDDLCommandSuite extends PlanTest {
}
}
- test("MSCK repair table (not supported)") {
- assertUnsupported("MSCK REPAIR TABLE tab1")
+ test("MSCK REPAIR table") {
+ val sql = "MSCK REPAIR TABLE tab1"
+ val parsed = parser.parsePlan(sql)
+ val expected = AlterTableRecoverPartitionsCommand(
+ TableIdentifier("tab1", None),
+ "MSCK REPAIR TABLE")
+ comparePlans(parsed, expected)
}
test("create table like") {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala
index 9697437dd2fe5..0b306a28d1a59 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala
@@ -87,11 +87,11 @@ private[streaming] class StreamingSource(ssc: StreamingContext) extends Source {
// Gauge for last received batch, useful for monitoring the streaming job's running status,
// with -1 displayed for any abnormal condition.
registerGaugeWithOption("lastReceivedBatch_submissionTime",
- _.lastCompletedBatch.map(_.submissionTime), -1L)
+ _.lastReceivedBatch.map(_.submissionTime), -1L)
registerGaugeWithOption("lastReceivedBatch_processingStartTime",
- _.lastCompletedBatch.flatMap(_.processingStartTime), -1L)
+ _.lastReceivedBatch.flatMap(_.processingStartTime), -1L)
registerGaugeWithOption("lastReceivedBatch_processingEndTime",
- _.lastCompletedBatch.flatMap(_.processingEndTime), -1L)
+ _.lastReceivedBatch.flatMap(_.processingEndTime), -1L)
// Gauge for last received batch records.
registerGauge("lastReceivedBatch_records", _.lastReceivedBatchRecords.values.sum, 0L)
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala
index 26b757cc2d535..46ab3ac8de3d4 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala
@@ -68,6 +68,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers {
listener.waitingBatches should be (List(BatchUIData(batchInfoSubmitted)))
listener.runningBatches should be (Nil)
listener.retainedCompletedBatches should be (Nil)
+ listener.lastReceivedBatch should be (Some(BatchUIData(batchInfoSubmitted)))
listener.lastCompletedBatch should be (None)
listener.numUnprocessedBatches should be (1)
listener.numTotalCompletedBatches should be (0)
@@ -81,6 +82,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers {
listener.waitingBatches should be (Nil)
listener.runningBatches should be (List(BatchUIData(batchInfoStarted)))
listener.retainedCompletedBatches should be (Nil)
+ listener.lastReceivedBatch should be (Some(BatchUIData(batchInfoStarted)))
listener.lastCompletedBatch should be (None)
listener.numUnprocessedBatches should be (1)
listener.numTotalCompletedBatches should be (0)
@@ -123,6 +125,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers {
listener.waitingBatches should be (Nil)
listener.runningBatches should be (Nil)
listener.retainedCompletedBatches should be (List(BatchUIData(batchInfoCompleted)))
+ listener.lastReceivedBatch should be (Some(BatchUIData(batchInfoCompleted)))
listener.lastCompletedBatch should be (Some(BatchUIData(batchInfoCompleted)))
listener.numUnprocessedBatches should be (0)
listener.numTotalCompletedBatches should be (1)
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
index 6b3c831e60472..ea63ff5dc1580 100644
--- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
+++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
@@ -125,8 +125,20 @@ private[spark] abstract class YarnSchedulerBackend(
* This includes executors already pending or running.
*/
override def doRequestTotalExecutors(requestedTotal: Int): Boolean = {
- yarnSchedulerEndpointRef.askWithRetry[Boolean](
- RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount))
+ val r = RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount)
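+ // Forward the request to the AM endpoint if it has registered; if it has not, or if the ask
+ // fails, report failure to the caller instead of throwing.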
+ yarnSchedulerEndpoint.amEndpoint match {
+ case Some(am) =>
+ try {
+ am.askWithRetry[Boolean](r)
+ } catch {
+ case NonFatal(e) =>
+ logError(s"Sending $r to AM was unsuccessful", e)
+ return false
+ }
+ case None =>
+ logWarning("Attempted to request executors before the AM has registered!")
+ return false
+ }
}
/**
@@ -209,7 +221,7 @@ private[spark] abstract class YarnSchedulerBackend(
*/
private class YarnSchedulerEndpoint(override val rpcEnv: RpcEnv)
extends ThreadSafeRpcEndpoint with Logging {
- private var amEndpoint: Option[RpcEndpointRef] = None
+ var amEndpoint: Option[RpcEndpointRef] = None
private val askAmThreadPool =
ThreadUtils.newDaemonCachedThreadPool("yarn-scheduler-ask-am-thread-pool")