diff --git a/LICENSE b/LICENSE
index 7950dd6ceb6db..b948ccaeecea6 100644
--- a/LICENSE
+++ b/LICENSE
@@ -249,11 +249,11 @@ The text of each license is also included at licenses/LICENSE-[project].txt.
(Interpreter classes (all .scala files in repl/src/main/scala
except for Main.Scala, SparkHelper.scala and ExecutorClassLoader.scala),
and for SerializableMapWrapper in JavaUtils.scala)
- (BSD-like) Scala Actors library (org.scala-lang:scala-actors:2.11.7 - http://www.scala-lang.org/)
- (BSD-like) Scala Compiler (org.scala-lang:scala-compiler:2.11.7 - http://www.scala-lang.org/)
- (BSD-like) Scala Compiler (org.scala-lang:scala-reflect:2.11.7 - http://www.scala-lang.org/)
- (BSD-like) Scala Library (org.scala-lang:scala-library:2.11.7 - http://www.scala-lang.org/)
- (BSD-like) Scalap (org.scala-lang:scalap:2.11.7 - http://www.scala-lang.org/)
+ (BSD-like) Scala Actors library (org.scala-lang:scala-actors:2.11.8 - http://www.scala-lang.org/)
+ (BSD-like) Scala Compiler (org.scala-lang:scala-compiler:2.11.8 - http://www.scala-lang.org/)
+ (BSD-like) Scala Compiler (org.scala-lang:scala-reflect:2.11.8 - http://www.scala-lang.org/)
+ (BSD-like) Scala Library (org.scala-lang:scala-library:2.11.8 - http://www.scala-lang.org/)
+ (BSD-like) Scalap (org.scala-lang:scalap:2.11.8 - http://www.scala-lang.org/)
(BSD-style) scalacheck (org.scalacheck:scalacheck_2.11:1.10.0 - http://www.scalacheck.org)
(BSD-style) spire (org.spire-math:spire_2.11:0.7.1 - http://spire-math.org)
(BSD-style) spire-macros (org.spire-math:spire-macros_2.11:0.7.1 - http://spire-math.org)
@@ -263,7 +263,7 @@ The text of each license is also included at licenses/LICENSE-[project].txt.
(New BSD license) Protocol Buffer Java API (org.spark-project.protobuf:protobuf-java:2.4.1-shaded - http://code.google.com/p/protobuf)
(The BSD License) Fortran to Java ARPACK (net.sourceforge.f2j:arpack_combined_all:0.1 - http://f2j.sourceforge.net)
(The BSD License) xmlenc Library (xmlenc:xmlenc:0.52 - http://xmlenc.sourceforge.net)
- (The New BSD License) Py4J (net.sf.py4j:py4j:0.10.4 - http://py4j.sourceforge.net/)
+ (The New BSD License) Py4J (net.sf.py4j:py4j:0.10.7 - http://py4j.sourceforge.net/)
(Two-clause BSD-style license) JUnit-Interface (com.novocode:junit-interface:0.10 - http://github.com/szeiger/junit-interface/)
(BSD licence) sbt and sbt-launch-lib.bash
(BSD 3 Clause) d3.min.js (https://github.com/mbostock/d3/blob/master/LICENSE)
@@ -297,3 +297,4 @@ The text of each license is also included at licenses/LICENSE-[project].txt.
(MIT License) RowsGroup (http://datatables.net/license/mit)
(MIT License) jsonFormatter (http://www.jqueryscript.net/other/jQuery-Plugin-For-Pretty-JSON-Formatting-jsonFormatter.html)
(MIT License) modernizr (https://github.com/Modernizr/Modernizr/blob/master/LICENSE)
+ (MIT License) machinist (https://github.com/typelevel/machinist)
diff --git a/NOTICE b/NOTICE
index f4b64b5c3f470..737189af09bf7 100644
--- a/NOTICE
+++ b/NOTICE
@@ -5,6 +5,31 @@ This product includes software developed at
The Apache Software Foundation (http://www.apache.org/).
+Export Control Notice
+---------------------
+
+This distribution includes cryptographic software. The country in which you currently reside may have
+restrictions on the import, possession, use, and/or re-export to another country, of encryption software.
+BEFORE using any encryption software, please check your country's laws, regulations and policies concerning
+the import, possession, or use, and re-export of encryption software, to see if this is permitted. See
+<http://www.wassenaar.org/> for more information.
+
+The U.S. Government Department of Commerce, Bureau of Industry and Security (BIS), has classified this
+software as Export Commodity Control Number (ECCN) 5D002.C.1, which includes information security software
+using or performing cryptographic functions with asymmetric algorithms. The form and manner of this Apache
+Software Foundation distribution makes it eligible for export under the License Exception ENC Technology
+Software Unrestricted (TSU) exception (see the BIS Export Administration Regulations, Section 740.13) for
+both object code and source code.
+
+The following provides more details on the included cryptographic software:
+
+This software uses Apache Commons Crypto (https://commons.apache.org/proper/commons-crypto/) to
+support authentication, and encryption and decryption of data sent across the network between
+services.
+
+This software includes Bouncy Castle (http://bouncycastle.org/) to support the jets3t library.
+
+
========================================================================
Common Development and Distribution License 1.0
========================================================================
diff --git a/R/README.md b/R/README.md
index 4c40c5963db70..1152b1e8e5f9f 100644
--- a/R/README.md
+++ b/R/README.md
@@ -66,11 +66,7 @@ To run one of them, use `./bin/spark-submit <filename> <args>`. For example:
```bash
./bin/spark-submit examples/src/main/r/dataframe.R
```
-You can also run the unit tests for SparkR by running. You need to install the [testthat](http://cran.r-project.org/web/packages/testthat/index.html) package first:
-```bash
-R -e 'install.packages("testthat", repos="http://cran.us.r-project.org")'
-./R/run-tests.sh
-```
+You can run R unit tests by following the instructions under [Running R Tests](http://spark.apache.org/docs/latest/building-spark.html#running-r-tests).
### Running on YARN
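For reference, the package set that the linked Running R Tests page asks for is sketched below; the exact list is assumed from the Spark docs at the time of this change, so the linked page remains the source of truth.

```r
# Assumed package list from the "Running R Tests" docs page; verify against the link above.
install.packages(c("knitr", "rmarkdown", "testthat", "e1071", "survival"),
                 repos = "https://cloud.r-project.org")
# then, from the Spark root directory: ./R/run-tests.sh
```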
diff --git a/R/WINDOWS.md b/R/WINDOWS.md
index 9ca7e58e20cd2..124bc631be9cd 100644
--- a/R/WINDOWS.md
+++ b/R/WINDOWS.md
@@ -34,10 +34,9 @@ To run the SparkR unit tests on Windows, the following steps are required —ass
4. Set the environment variable `HADOOP_HOME` to the full path to the newly created `hadoop` directory.
-5. Run unit tests for SparkR by running the command below. You need to install the [testthat](http://cran.r-project.org/web/packages/testthat/index.html) package first:
+5. Run unit tests for SparkR by running the command below. You need to install the needed packages following the instructions under [Running R Tests](http://spark.apache.org/docs/latest/building-spark.html#running-r-tests) first:
```
- R -e "install.packages('testthat', repos='http://cran.us.r-project.org')"
.\bin\spark-submit2.cmd --conf spark.hadoop.fs.defaultFS="file:///" R\pkg\tests\run-all.R
```
diff --git a/R/install-dev.sh b/R/install-dev.sh
index d613552718307..9fbc999f2e805 100755
--- a/R/install-dev.sh
+++ b/R/install-dev.sh
@@ -28,6 +28,7 @@
set -o pipefail
set -e
+set -x
FWDIR="$(cd "`dirname "${BASH_SOURCE[0]}"`"; pwd)"
LIB_DIR="$FWDIR/lib"
diff --git a/R/pkg/.Rbuildignore b/R/pkg/.Rbuildignore
index f12f8c275a989..18b2db69db8f1 100644
--- a/R/pkg/.Rbuildignore
+++ b/R/pkg/.Rbuildignore
@@ -6,3 +6,4 @@
^README\.Rmd$
^src-native$
^html$
+^tests/fulltests/*
diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION
index 879c1f80f2c5d..ad723300490f1 100644
--- a/R/pkg/DESCRIPTION
+++ b/R/pkg/DESCRIPTION
@@ -1,8 +1,8 @@
Package: SparkR
Type: Package
-Version: 2.2.0
+Version: 2.2.3
Title: R Frontend for Apache Spark
-Description: The SparkR package provides an R Frontend for Apache Spark.
+Description: Provides an R Frontend for Apache Spark.
Authors@R: c(person("Shivaram", "Venkataraman", role = c("aut", "cre"),
email = "shivaram@cs.berkeley.edu"),
person("Xiangrui", "Meng", role = "aut",
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index ca45c6f9b0a96..44e39c4abb472 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -122,6 +122,7 @@ exportMethods("arrange",
"group_by",
"groupBy",
"head",
+ "hint",
"insertInto",
"intersect",
"isLocal",
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 88a138fd8eb1f..22e62532b9b5d 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -591,7 +591,7 @@ setMethod("cache",
#'
#' Persist this SparkDataFrame with the specified storage level. For details of the
#' supported storage levels, refer to
-#' \url{http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence}.
+#' \url{http://spark.apache.org/docs/latest/rdd-programming-guide.html#rdd-persistence}.
#'
#' @param x the SparkDataFrame to persist.
#' @param newLevel storage level chosen for the persistance. See available options in
@@ -1174,6 +1174,9 @@ setMethod("collect",
vec <- do.call(c, col)
stopifnot(class(vec) != "list")
class(vec) <- PRIMITIVE_TYPES[[colType]]
+ if (is.character(vec) && stringsAsFactors) {
+ vec <- as.factor(vec)
+ }
df[[colIndex]] <- vec
} else {
df[[colIndex]] <- col
@@ -2642,6 +2645,7 @@ generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) {
#' Input SparkDataFrames can have different schemas (names and data types).
#'
#' Note: This does not remove duplicate rows across the two SparkDataFrames.
+#' Also as standard in SQL, this function resolves columns by position (not by name).
#'
#' @param x A SparkDataFrame
#' @param y A SparkDataFrame
@@ -3114,7 +3118,7 @@ setMethod("as.data.frame",
#'
#' @family SparkDataFrame functions
#' @rdname attach
-#' @aliases attach,SparkDataFrame-method
+#' @aliases attach attach,SparkDataFrame-method
#' @param what (SparkDataFrame) The SparkDataFrame to attach
#' @param pos (integer) Specify position in search() where to attach.
#' @param name (character) Name to use for the attached SparkDataFrame. Names
@@ -3130,9 +3134,12 @@ setMethod("as.data.frame",
#' @note attach since 1.6.0
setMethod("attach",
signature(what = "SparkDataFrame"),
- function(what, pos = 2, name = deparse(substitute(what)), warn.conflicts = TRUE) {
- newEnv <- assignNewEnv(what)
- attach(newEnv, pos = pos, name = name, warn.conflicts = warn.conflicts)
+ function(what, pos = 2L, name = deparse(substitute(what), backtick = FALSE),
+ warn.conflicts = TRUE) {
+ args <- as.list(environment()) # capture all parameters - this must be the first line
+ newEnv <- assignNewEnv(args$what)
+ args$what <- newEnv
+ do.call(attach, args)
})
#' Evaluate a R expression in an environment constructed from a SparkDataFrame
@@ -3642,3 +3649,33 @@ setMethod("checkpoint",
df <- callJMethod(x@sdf, "checkpoint", as.logical(eager))
dataFrame(df)
})
+
+#' hint
+#'
+#' Specifies execution plan hint and return a new SparkDataFrame.
+#'
+#' @param x a SparkDataFrame.
+#' @param name a name of the hint.
+#' @param ... optional parameters for the hint.
+#' @return A SparkDataFrame.
+#' @family SparkDataFrame functions
+#' @aliases hint,SparkDataFrame,character-method
+#' @rdname hint
+#' @name hint
+#' @export
+#' @examples
+#' \dontrun{
+#' df <- createDataFrame(mtcars)
+#' avg_mpg <- mean(groupBy(createDataFrame(mtcars), "cyl"), "mpg")
+#'
+#' head(join(df, hint(avg_mpg, "broadcast"), df$cyl == avg_mpg$cyl))
+#' }
+#' @note hint since 2.2.0
+setMethod("hint",
+ signature(x = "SparkDataFrame", name = "character"),
+ function(x, name, ...) {
+ parameters <- list(...)
+ stopifnot(all(sapply(parameters, is.character)))
+ jdf <- callJMethod(x@sdf, "hint", name, parameters)
+ dataFrame(jdf)
+ })
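A minimal usage sketch for the collect() change above (hint() already carries a roxygen example); the data frame is illustrative and assumes an active SparkSession.

```r
df <- createDataFrame(data.frame(x = c("a", "b"), y = 1:2, stringsAsFactors = FALSE))
str(collect(df))                           # x comes back as character (default)
str(collect(df, stringsAsFactors = TRUE))  # x is converted to a factor by the new branch
```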
diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R
index 7ad3993e9ecbc..15ca212acf87f 100644
--- a/R/pkg/R/RDD.R
+++ b/R/pkg/R/RDD.R
@@ -227,7 +227,7 @@ setMethod("cacheRDD",
#'
#' Persist this RDD with the specified storage level. For details of the
#' supported storage levels, refer to
-#'\url{http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence}.
+#'\url{http://spark.apache.org/docs/latest/rdd-programming-guide.html#rdd-persistence}.
#'
#' @param x The RDD to persist
#' @param newLevel The new storage level to be assigned
diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R
index f5c3a749fe0a1..e3528bc7c3135 100644
--- a/R/pkg/R/SQLContext.R
+++ b/R/pkg/R/SQLContext.R
@@ -334,7 +334,7 @@ setMethod("toDF", signature(x = "RDD"),
#'
#' Loads a JSON file, returning the result as a SparkDataFrame
#' By default, (\href{http://jsonlines.org/}{JSON Lines text format or newline-delimited JSON}
-#' ) is supported. For JSON (one record per file), set a named property \code{wholeFile} to
+#' ) is supported. For JSON (one record per file), set a named property \code{multiLine} to
#' \code{TRUE}.
#' It goes through the entire dataset once to determine the schema.
#'
@@ -348,7 +348,7 @@ setMethod("toDF", signature(x = "RDD"),
#' sparkR.session()
#' path <- "path/to/file.json"
#' df <- read.json(path)
-#' df <- read.json(path, wholeFile = TRUE)
+#' df <- read.json(path, multiLine = TRUE)
#' df <- jsonFile(path)
#' }
#' @name read.json
@@ -598,7 +598,7 @@ tableToDF <- function(tableName) {
#' df1 <- read.df("path/to/file.json", source = "json")
#' schema <- structType(structField("name", "string"),
#' structField("info", "map"))
-#' df2 <- read.df(mapTypeJsonPath, "json", schema, wholeFile = TRUE)
+#' df2 <- read.df(mapTypeJsonPath, "json", schema, multiLine = TRUE)
#' df3 <- loadDF("data/test_table", "parquet", mergeSchema = "true")
#' }
#' @name read.df
diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R
index 9d82814211bc5..7244cc9f9e38e 100644
--- a/R/pkg/R/client.R
+++ b/R/pkg/R/client.R
@@ -19,7 +19,7 @@
# Creates a SparkR client connection object
# if one doesn't already exist
-connectBackend <- function(hostname, port, timeout) {
+connectBackend <- function(hostname, port, timeout, authSecret) {
if (exists(".sparkRcon", envir = .sparkREnv)) {
if (isOpen(.sparkREnv[[".sparkRCon"]])) {
cat("SparkRBackend client connection already exists\n")
@@ -29,7 +29,7 @@ connectBackend <- function(hostname, port, timeout) {
con <- socketConnection(host = hostname, port = port, server = FALSE,
blocking = TRUE, open = "wb", timeout = timeout)
-
+ doServerAuth(con, authSecret)
assign(".sparkRCon", con, envir = .sparkREnv)
con
}
diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R
index 0e99b171cabeb..dc7d37e064b1d 100644
--- a/R/pkg/R/deserialize.R
+++ b/R/pkg/R/deserialize.R
@@ -60,14 +60,18 @@ readTypedObject <- function(con, type) {
stop(paste("Unsupported type for deserialization", type)))
}
-readString <- function(con) {
- stringLen <- readInt(con)
- raw <- readBin(con, raw(), stringLen, endian = "big")
+readStringData <- function(con, len) {
+ raw <- readBin(con, raw(), len, endian = "big")
string <- rawToChar(raw)
Encoding(string) <- "UTF-8"
string
}
+readString <- function(con) {
+ stringLen <- readInt(con)
+ readStringData(con, stringLen)
+}
+
readInt <- function(con) {
readBin(con, integer(), n = 1, endian = "big")
}
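The split of readString() matters for the handshake code in sparkR.R later in this diff: the SerDe encodes a string as a big-endian int length followed by UTF-8 bytes, so a caller can read the length itself and fail with a clearer error. A sketch of the equivalence, assuming `con` is a connection opened by SparkR:

```r
# readString(con) is now equivalent to:
len <- readInt(con)               # 4-byte big-endian length prefix
str <- readStringData(con, len)   # UTF-8 payload of exactly `len` bytes
```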
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 945676c7f10b3..5ee4216e34554 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -398,7 +398,8 @@ setGeneric("as.data.frame",
standardGeneric("as.data.frame")
})
-#' @rdname attach
+# Do not document the generic because of signature changes across R versions
+#' @noRd
#' @export
setGeneric("attach")
@@ -572,6 +573,10 @@ setGeneric("group_by", function(x, ...) { standardGeneric("group_by") })
#' @export
setGeneric("groupBy", function(x, ...) { standardGeneric("groupBy") })
+#' @rdname hint
+#' @export
+setGeneric("hint", function(x, name, ...) { standardGeneric("hint") })
+
#' @rdname insertInto
#' @export
setGeneric("insertInto", function(x, tableName, ...) { standardGeneric("insertInto") })
@@ -1356,12 +1361,9 @@ setGeneric("year", function(x) { standardGeneric("year") })
#' @export
setGeneric("fitted")
-#' @param x,y For \code{glm}: logical values indicating whether the response vector
-#' and model matrix used in the fitting process should be returned as
-#' components of the returned value.
-#' @inheritParams stats::glm
-#' @rdname glm
+# Do not carry stats::glm usage and param here, and do not document the generic
#' @export
+#' @noRd
setGeneric("glm")
#' @param object a fitted ML model object.
@@ -1469,7 +1471,7 @@ setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml")
#' @rdname awaitTermination
#' @export
-setGeneric("awaitTermination", function(x, timeout) { standardGeneric("awaitTermination") })
+setGeneric("awaitTermination", function(x, timeout = NULL) { standardGeneric("awaitTermination") })
#' @rdname isActive
#' @export
diff --git a/R/pkg/R/install.R b/R/pkg/R/install.R
index 4ca7aa664e023..04dc7562e5346 100644
--- a/R/pkg/R/install.R
+++ b/R/pkg/R/install.R
@@ -152,6 +152,11 @@ install.spark <- function(hadoopVersion = "2.7", mirrorUrl = NULL,
})
if (!tarExists || overwrite || !success) {
unlink(packageLocalPath)
+ if (success) {
+ # if tar file was not there before (or it was, but we are told to overwrite it),
+ # and untar is successful - set a flag that we have downloaded (and untar) Spark package.
+ assign(".sparkDownloaded", TRUE, envir = .sparkREnv)
+ }
}
if (!success) stop("Extract archive failed.")
message("DONE.")
@@ -266,11 +271,16 @@ hadoopVersionName <- function(hadoopVersion) {
# The implementation refers to appdirs package: https://pypi.python.org/pypi/appdirs and
# adapt to Spark context
+# see also sparkCacheRelPathLength()
sparkCachePath <- function() {
- if (.Platform$OS.type == "windows") {
+ if (is_windows()) {
winAppPath <- Sys.getenv("LOCALAPPDATA", unset = NA)
if (is.na(winAppPath)) {
- stop(paste("%LOCALAPPDATA% not found.",
+ message("%LOCALAPPDATA% not found. Falling back to %USERPROFILE%.")
+ winAppPath <- Sys.getenv("USERPROFILE", unset = NA)
+ }
+ if (is.na(winAppPath)) {
+ stop(paste("%LOCALAPPDATA% and %USERPROFILE% not found.",
"Please define the environment variable",
"or restart and enter an installation path in localDir."))
} else {
@@ -278,7 +288,7 @@ sparkCachePath <- function() {
}
} else if (.Platform$OS.type == "unix") {
if (Sys.info()["sysname"] == "Darwin") {
- path <- file.path(Sys.getenv("HOME"), "Library/Caches", "spark")
+ path <- file.path(Sys.getenv("HOME"), "Library", "Caches", "spark")
} else {
path <- file.path(
Sys.getenv("XDG_CACHE_HOME", file.path(Sys.getenv("HOME"), ".cache")), "spark")
@@ -289,6 +299,16 @@ sparkCachePath <- function() {
normalizePath(path, mustWork = FALSE)
}
+# Length of the Spark cache specific relative path segments for each platform
+# eg. "Apache\Spark\Cache" is 3 in Windows, or "spark" is 1 in unix
+# Must match sparkCachePath() exactly.
+sparkCacheRelPathLength <- function() {
+ if (is_windows()) {
+ 3
+ } else {
+ 1
+ }
+}
installInstruction <- function(mode) {
if (mode == "remote") {
@@ -306,3 +326,22 @@ installInstruction <- function(mode) {
stop(paste0("No instruction found for ", mode, " mode."))
}
}
+
+uninstallDownloadedSpark <- function() {
+ # clean up if Spark was downloaded
+ sparkDownloaded <- getOne(".sparkDownloaded",
+ envir = .sparkREnv,
+ inherits = TRUE,
+ ifnotfound = FALSE)
+ sparkDownloadedDir <- Sys.getenv("SPARK_HOME")
+ if (sparkDownloaded && nchar(sparkDownloadedDir) > 0) {
+ unlink(sparkDownloadedDir, recursive = TRUE, force = TRUE)
+
+ dirs <- traverseParentDirs(sparkCachePath(), sparkCacheRelPathLength())
+ lapply(dirs, function(d) {
+ if (length(list.files(d, all.files = TRUE, include.dirs = TRUE, no.. = TRUE)) == 0) {
+ unlink(d, recursive = TRUE, force = TRUE)
+ }
+ })
+ }
+}
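A sketch of how the new install.R pieces fit together; the path is illustrative (unix layout) and traverseParentDirs()/getOne() are defined in R/pkg/R/utils.R further down in this diff.

```r
# install.spark() records a successful download-and-untar via
#   assign(".sparkDownloaded", TRUE, envir = .sparkREnv)
# uninstallDownloadedSpark() then deletes SPARK_HOME and walks up the cache path,
# removing only directories that end up empty:
traverseParentDirs("~/.cache/spark", sparkCacheRelPathLength())
# [1] "~/.cache/spark" "~/.cache"   (on unix, where the relative path length is 1)
```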
diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R
index 4db9cc30fb0c1..bdcc0818d139d 100644
--- a/R/pkg/R/mllib_classification.R
+++ b/R/pkg/R/mllib_classification.R
@@ -46,22 +46,25 @@ setClass("MultilayerPerceptronClassificationModel", representation(jobj = "jobj"
#' @note NaiveBayesModel since 2.0.0
setClass("NaiveBayesModel", representation(jobj = "jobj"))
-#' linear SVM Model
+#' Linear SVM Model
#'
-#' Fits an linear SVM model against a SparkDataFrame. It is a binary classifier, similar to svm in glmnet package
+#' Fits a linear SVM model against a SparkDataFrame, similar to svm in e1071 package.
+#' Currently only supports binary classification model with linear kernel.
#' Users can print, make predictions on the produced model and save the model to the input path.
#'
#' @param data SparkDataFrame for training.
#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
#' operators are supported, including '~', '.', ':', '+', and '-'.
-#' @param regParam The regularization parameter.
+#' @param regParam The regularization parameter. Only supports L2 regularization currently.
#' @param maxIter Maximum iteration number.
#' @param tol Convergence tolerance of iterations.
#' @param standardization Whether to standardize the training features before fitting the model. The coefficients
#' of models will be always returned on the original scale, so it will be transparent for
#' users. Note that with/without standardization, the models should be always converged
#' to the same solution when no regularization is applied.
-#' @param threshold The threshold in binary classification, in range [0, 1].
+#' @param threshold The threshold in binary classification applied to the linear model prediction.
+#' This threshold can be any real number, where Inf will make all predictions 0.0
+#' and -Inf will make all predictions 1.0.
#' @param weightCol The weight column name.
#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features
#' or the number of partitions are large, this param could be adjusted to a larger size.
@@ -111,10 +114,10 @@ setMethod("spark.svmLinear", signature(data = "SparkDataFrame", formula = "formu
new("LinearSVCModel", jobj = jobj)
})
-# Predicted values based on an LinearSVCModel model
+# Predicted values based on a LinearSVCModel model
#' @param newData a SparkDataFrame for testing.
-#' @return \code{predict} returns the predicted values based on an LinearSVCModel.
+#' @return \code{predict} returns the predicted values based on a LinearSVCModel.
#' @rdname spark.svmLinear
#' @aliases predict,LinearSVCModel,SparkDataFrame-method
#' @export
@@ -124,13 +127,12 @@ setMethod("predict", signature(object = "LinearSVCModel"),
predict_internal(object, newData)
})
-# Get the summary of an LinearSVCModel
+# Get the summary of a LinearSVCModel
-#' @param object an LinearSVCModel fitted by \code{spark.svmLinear}.
+#' @param object a LinearSVCModel fitted by \code{spark.svmLinear}.
#' @return \code{summary} returns summary information of the fitted model, which is a list.
#' The list includes \code{coefficients} (coefficients of the fitted model),
-#' \code{intercept} (intercept of the fitted model), \code{numClasses} (number of classes),
-#' \code{numFeatures} (number of features).
+#' \code{numClasses} (number of classes), \code{numFeatures} (number of features).
#' @rdname spark.svmLinear
#' @aliases summary,LinearSVCModel-method
#' @export
@@ -138,22 +140,14 @@ setMethod("predict", signature(object = "LinearSVCModel"),
setMethod("summary", signature(object = "LinearSVCModel"),
function(object) {
jobj <- object@jobj
- features <- callJMethod(jobj, "features")
- labels <- callJMethod(jobj, "labels")
- coefficients <- callJMethod(jobj, "coefficients")
- nCol <- length(coefficients) / length(features)
- coefficients <- matrix(unlist(coefficients), ncol = nCol)
- intercept <- callJMethod(jobj, "intercept")
+ features <- callJMethod(jobj, "rFeatures")
+ coefficients <- callJMethod(jobj, "rCoefficients")
+ coefficients <- as.matrix(unlist(coefficients))
+ colnames(coefficients) <- c("Estimate")
+ rownames(coefficients) <- unlist(features)
numClasses <- callJMethod(jobj, "numClasses")
numFeatures <- callJMethod(jobj, "numFeatures")
- if (nCol == 1) {
- colnames(coefficients) <- c("Estimate")
- } else {
- colnames(coefficients) <- unlist(labels)
- }
- rownames(coefficients) <- unlist(features)
- list(coefficients = coefficients, intercept = intercept,
- numClasses = numClasses, numFeatures = numFeatures)
+ list(coefficients = coefficients, numClasses = numClasses, numFeatures = numFeatures)
})
# Save fitted LinearSVCModel to the input path
diff --git a/R/pkg/R/mllib_regression.R b/R/pkg/R/mllib_regression.R
index d59c890f3e5fd..352e37199a2c0 100644
--- a/R/pkg/R/mllib_regression.R
+++ b/R/pkg/R/mllib_regression.R
@@ -169,6 +169,7 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
#' @param link.power the index of the power link function in the Tweedie family.
#' @return \code{glm} returns a fitted generalized linear model.
#' @rdname glm
+#' @aliases glm
#' @export
#' @examples
#' \dontrun{
diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R
index d0a12b7ecec65..22a99e223e3bb 100644
--- a/R/pkg/R/sparkR.R
+++ b/R/pkg/R/sparkR.R
@@ -161,6 +161,10 @@ sparkR.sparkContext <- function(
" please use the --packages commandline instead", sep = ","))
}
backendPort <- existingPort
+ authSecret <- Sys.getenv("SPARKR_BACKEND_AUTH_SECRET")
+ if (nchar(authSecret) == 0) {
+ stop("Auth secret not provided in environment.")
+ }
} else {
path <- tempfile(pattern = "backend_port")
submitOps <- getClientModeSparkSubmitOpts(
@@ -189,16 +193,27 @@ sparkR.sparkContext <- function(
monitorPort <- readInt(f)
rLibPath <- readString(f)
connectionTimeout <- readInt(f)
+
+ # Don't use readString() so that we can provide a useful
+ # error message if the R and Java versions are mismatched.
+ authSecretLen <- readInt(f)
+ if (length(authSecretLen) == 0 || authSecretLen == 0) {
+ stop("Unexpected EOF in JVM connection data. Mismatched versions?")
+ }
+ authSecret <- readStringData(f, authSecretLen)
close(f)
file.remove(path)
if (length(backendPort) == 0 || backendPort == 0 ||
length(monitorPort) == 0 || monitorPort == 0 ||
- length(rLibPath) != 1) {
+ length(rLibPath) != 1 || length(authSecret) == 0) {
stop("JVM failed to launch")
}
- assign(".monitorConn",
- socketConnection(port = monitorPort, timeout = connectionTimeout),
- envir = .sparkREnv)
+
+ monitorConn <- socketConnection(port = monitorPort, blocking = TRUE,
+ timeout = connectionTimeout, open = "wb")
+ doServerAuth(monitorConn, authSecret)
+
+ assign(".monitorConn", monitorConn, envir = .sparkREnv)
assign(".backendLaunched", 1, envir = .sparkREnv)
if (rLibPath != "") {
assign(".libPath", rLibPath, envir = .sparkREnv)
@@ -208,7 +223,7 @@ sparkR.sparkContext <- function(
.sparkREnv$backendPort <- backendPort
tryCatch({
- connectBackend("localhost", backendPort, timeout = connectionTimeout)
+ connectBackend("localhost", backendPort, timeout = connectionTimeout, authSecret = authSecret)
},
error = function(err) {
stop("Failed to connect JVM\n")
@@ -420,6 +435,18 @@ sparkR.session <- function(
enableHiveSupport)
assign(".sparkRsession", sparkSession, envir = .sparkREnv)
}
+
+ # Check if version number of SparkSession matches version number of SparkR package
+ jvmVersion <- callJMethod(sparkSession, "version")
+ # Remove -SNAPSHOT from jvm versions
+ jvmVersionStrip <- gsub("-SNAPSHOT", "", jvmVersion)
+ rPackageVersion <- paste0(packageVersion("SparkR"))
+
+ if (jvmVersionStrip != rPackageVersion) {
+ warning(paste("Version mismatch between Spark JVM and SparkR package. JVM version was",
+ jvmVersion, ", while R package version was", rPackageVersion))
+ }
+
sparkSession
}
@@ -620,3 +647,17 @@ sparkCheckInstall <- function(sparkHome, master, deployMode) {
NULL
}
}
+
+# Utility function for sending auth data over a socket and checking the server's reply.
+doServerAuth <- function(con, authSecret) {
+ if (nchar(authSecret) == 0) {
+ stop("Auth secret not provided.")
+ }
+ writeString(con, authSecret)
+ flush(con)
+ reply <- readString(con)
+ if (reply != "ok") {
+ close(con)
+ stop("Unexpected reply from server.")
+ }
+}
diff --git a/R/pkg/R/streaming.R b/R/pkg/R/streaming.R
index e353d2dd07c3d..8390bd5e6de72 100644
--- a/R/pkg/R/streaming.R
+++ b/R/pkg/R/streaming.R
@@ -169,8 +169,10 @@ setMethod("isActive",
#' immediately.
#'
#' @param x a StreamingQuery.
-#' @param timeout time to wait in milliseconds
-#' @return TRUE if query has terminated within the timeout period.
+#' @param timeout time to wait in milliseconds, if omitted, wait indefinitely until \code{stopQuery}
+#' is called or an error has occured.
+#' @return TRUE if query has terminated within the timeout period; nothing if timeout is not
+#' specified.
#' @rdname awaitTermination
#' @name awaitTermination
#' @aliases awaitTermination,StreamingQuery-method
@@ -182,8 +184,12 @@ setMethod("isActive",
#' @note experimental
setMethod("awaitTermination",
signature(x = "StreamingQuery"),
- function(x, timeout) {
- handledCallJMethod(x@ssq, "awaitTermination", as.integer(timeout))
+ function(x, timeout = NULL) {
+ if (is.null(timeout)) {
+ invisible(handledCallJMethod(x@ssq, "awaitTermination"))
+ } else {
+ handledCallJMethod(x@ssq, "awaitTermination", as.integer(timeout))
+ }
})
#' stopQuery
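Usage sketch for the awaitTermination() change; `q` stands for a StreamingQuery returned by write.stream() and is an assumption here.

```r
awaitTermination(q, 10000)  # returns TRUE/FALSE once the query terminates or 10 seconds pass
awaitTermination(q)         # new behaviour: blocks (invisibly) until stopQuery(q) or an error
```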
diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R
index fbc89e98847bf..da5f3cf5bdc32 100644
--- a/R/pkg/R/utils.R
+++ b/R/pkg/R/utils.R
@@ -899,3 +899,28 @@ basenameSansExtFromUrl <- function(url) {
isAtomicLengthOne <- function(x) {
is.atomic(x) && length(x) == 1
}
+
+is_windows <- function() {
+ .Platform$OS.type == "windows"
+}
+
+hadoop_home_set <- function() {
+ !identical(Sys.getenv("HADOOP_HOME"), "")
+}
+
+windows_with_hadoop <- function() {
+ !is_windows() || hadoop_home_set()
+}
+
+# get0 not supported before R 3.2.0
+getOne <- function(x, envir, inherits = TRUE, ifnotfound = NULL) {
+ mget(x[1L], envir = envir, inherits = inherits, ifnotfound = list(ifnotfound))[[1L]]
+}
+
+# Returns a vector of parent directories, traversing up count times, starting with a full path
+# eg. traverseParentDirs("/Users/user/Library/Caches/spark/spark2.2", 1) should return
+# this "/Users/user/Library/Caches/spark/spark2.2"
+# and "/Users/user/Library/Caches/spark"
+traverseParentDirs <- function(x, count) {
+ if (dirname(x) == x || count <= 0) x else c(x, Recall(dirname(x), count - 1))
+}
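For reference, the expected behaviour of two of the helpers above; values shown are illustrative.

```r
# getOne() is a get0() substitute for R < 3.2.0:
getOne(".sparkDownloaded", envir = .sparkREnv, ifnotfound = FALSE)  # FALSE until install.spark() sets it
# windows_with_hadoop() gates the model save/load tests rewritten below:
windows_with_hadoop()  # TRUE on unix; on Windows only when HADOOP_HOME is set
```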
diff --git a/R/pkg/inst/tests/testthat/test_basic.R b/R/pkg/inst/tests/testthat/test_basic.R
new file mode 100644
index 0000000000000..823d26f12feee
--- /dev/null
+++ b/R/pkg/inst/tests/testthat/test_basic.R
@@ -0,0 +1,92 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+context("basic tests for CRAN")
+
+test_that("create DataFrame from list or data.frame", {
+ sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE,
+ sparkConfig = sparkRTestConfig)
+
+ i <- 4
+ df <- createDataFrame(data.frame(dummy = 1:i))
+ expect_equal(count(df), i)
+
+ l <- list(list(a = 1, b = 2), list(a = 3, b = 4))
+ df <- createDataFrame(l)
+ expect_equal(columns(df), c("a", "b"))
+
+ a <- 1:3
+ b <- c("a", "b", "c")
+ ldf <- data.frame(a, b)
+ df <- createDataFrame(ldf)
+ expect_equal(columns(df), c("a", "b"))
+ expect_equal(dtypes(df), list(c("a", "int"), c("b", "string")))
+ expect_equal(count(df), 3)
+ ldf2 <- collect(df)
+ expect_equal(ldf$a, ldf2$a)
+
+ mtcarsdf <- createDataFrame(mtcars)
+ expect_equivalent(collect(mtcarsdf), mtcars)
+
+ bytes <- as.raw(c(1, 2, 3))
+ df <- createDataFrame(list(list(bytes)))
+ expect_equal(collect(df)[[1]][[1]], bytes)
+
+ sparkR.session.stop()
+})
+
+test_that("spark.glm and predict", {
+ sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE,
+ sparkConfig = sparkRTestConfig)
+
+ training <- suppressWarnings(createDataFrame(iris))
+ # gaussian family
+ model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species)
+ prediction <- predict(model, training)
+ expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double")
+ vals <- collect(select(prediction, "prediction"))
+ rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris)
+ expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
+
+ # Gamma family
+ x <- runif(100, -1, 1)
+ y <- rgamma(100, rate = 10 / exp(0.5 + 1.2 * x), shape = 10)
+ df <- as.DataFrame(as.data.frame(list(x = x, y = y)))
+ model <- glm(y ~ x, family = Gamma, df)
+ out <- capture.output(print(summary(model)))
+ expect_true(any(grepl("Dispersion parameter for gamma family", out)))
+
+ # tweedie family
+ model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species,
+ family = "tweedie", var.power = 1.2, link.power = 0.0)
+ prediction <- predict(model, training)
+ expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double")
+ vals <- collect(select(prediction, "prediction"))
+
+ # manual calculation of the R predicted values to avoid dependence on statmod
+ #' library(statmod)
+ #' rModel <- glm(Sepal.Width ~ Sepal.Length + Species, data = iris,
+ #' family = tweedie(var.power = 1.2, link.power = 0.0))
+ #' print(coef(rModel))
+
+ rCoef <- c(0.6455409, 0.1169143, -0.3224752, -0.3282174)
+ rVals <- exp(as.numeric(model.matrix(Sepal.Width ~ Sepal.Length + Species,
+ data = iris) %*% rCoef))
+ expect_true(all(abs(rVals - vals) < 1e-5), rVals - vals)
+
+ sparkR.session.stop()
+})
diff --git a/R/pkg/inst/worker/daemon.R b/R/pkg/inst/worker/daemon.R
index 3a318b71ea06d..ec9a8f1ee1c95 100644
--- a/R/pkg/inst/worker/daemon.R
+++ b/R/pkg/inst/worker/daemon.R
@@ -28,7 +28,9 @@ suppressPackageStartupMessages(library(SparkR))
port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT"))
inputCon <- socketConnection(
- port = port, open = "rb", blocking = TRUE, timeout = connectionTimeout)
+ port = port, open = "wb", blocking = TRUE, timeout = connectionTimeout)
+
+SparkR:::doServerAuth(inputCon, Sys.getenv("SPARKR_WORKER_SECRET"))
while (TRUE) {
ready <- socketSelect(list(inputCon))
diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R
index 03e7450147865..eb6453fc16976 100644
--- a/R/pkg/inst/worker/worker.R
+++ b/R/pkg/inst/worker/worker.R
@@ -100,9 +100,12 @@ suppressPackageStartupMessages(library(SparkR))
port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT"))
inputCon <- socketConnection(
- port = port, blocking = TRUE, open = "rb", timeout = connectionTimeout)
+ port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout)
+SparkR:::doServerAuth(inputCon, Sys.getenv("SPARKR_WORKER_SECRET"))
+
outputCon <- socketConnection(
port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout)
+SparkR:::doServerAuth(outputCon, Sys.getenv("SPARKR_WORKER_SECRET"))
# read the index of the current partition inside the RDD
partition <- SparkR:::readInt(inputCon)
diff --git a/R/pkg/inst/tests/testthat/jarTest.R b/R/pkg/tests/fulltests/jarTest.R
similarity index 96%
rename from R/pkg/inst/tests/testthat/jarTest.R
rename to R/pkg/tests/fulltests/jarTest.R
index c9615c8d4faf6..e2241e03b55f8 100644
--- a/R/pkg/inst/tests/testthat/jarTest.R
+++ b/R/pkg/tests/fulltests/jarTest.R
@@ -16,7 +16,7 @@
#
library(SparkR)
-sc <- sparkR.session()
+sc <- sparkR.session(master = "local[1]")
helloTest <- SparkR:::callJStatic("sparkrtest.DummyClass",
"helloWorld",
diff --git a/R/pkg/inst/tests/testthat/packageInAJarTest.R b/R/pkg/tests/fulltests/packageInAJarTest.R
similarity index 96%
rename from R/pkg/inst/tests/testthat/packageInAJarTest.R
rename to R/pkg/tests/fulltests/packageInAJarTest.R
index 4bc935c79eb0f..ac706261999fb 100644
--- a/R/pkg/inst/tests/testthat/packageInAJarTest.R
+++ b/R/pkg/tests/fulltests/packageInAJarTest.R
@@ -17,7 +17,7 @@
library(SparkR)
library(sparkPackageTest)
-sparkR.session()
+sparkR.session(master = "local[1]")
run1 <- myfunc(5L)
diff --git a/R/pkg/inst/tests/testthat/test_Serde.R b/R/pkg/tests/fulltests/test_Serde.R
similarity index 96%
rename from R/pkg/inst/tests/testthat/test_Serde.R
rename to R/pkg/tests/fulltests/test_Serde.R
index b5f6f1b54fa85..6bbd201bf1d82 100644
--- a/R/pkg/inst/tests/testthat/test_Serde.R
+++ b/R/pkg/tests/fulltests/test_Serde.R
@@ -17,7 +17,7 @@
context("SerDe functionality")
-sparkSession <- sparkR.session(enableHiveSupport = FALSE)
+sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE)
test_that("SerDe of primitive types", {
x <- callJStatic("SparkRHandler", "echo", 1L)
diff --git a/R/pkg/inst/tests/testthat/test_Windows.R b/R/pkg/tests/fulltests/test_Windows.R
similarity index 96%
rename from R/pkg/inst/tests/testthat/test_Windows.R
rename to R/pkg/tests/fulltests/test_Windows.R
index 1d777ddb286df..b2ec6c67311db 100644
--- a/R/pkg/inst/tests/testthat/test_Windows.R
+++ b/R/pkg/tests/fulltests/test_Windows.R
@@ -17,7 +17,7 @@
context("Windows-specific tests")
test_that("sparkJars tag in SparkContext", {
- if (.Platform$OS.type != "windows") {
+ if (!is_windows()) {
skip("This test is only for Windows, skipped")
}
diff --git a/R/pkg/inst/tests/testthat/test_binaryFile.R b/R/pkg/tests/fulltests/test_binaryFile.R
similarity index 97%
rename from R/pkg/inst/tests/testthat/test_binaryFile.R
rename to R/pkg/tests/fulltests/test_binaryFile.R
index b5c279e3156e5..758b174b8787c 100644
--- a/R/pkg/inst/tests/testthat/test_binaryFile.R
+++ b/R/pkg/tests/fulltests/test_binaryFile.R
@@ -18,7 +18,7 @@
context("functions on binary files")
# JavaSparkContext handle
-sparkSession <- sparkR.session(enableHiveSupport = FALSE)
+sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE)
sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)
mockFile <- c("Spark is pretty.", "Spark is awesome.")
diff --git a/R/pkg/inst/tests/testthat/test_binary_function.R b/R/pkg/tests/fulltests/test_binary_function.R
similarity index 97%
rename from R/pkg/inst/tests/testthat/test_binary_function.R
rename to R/pkg/tests/fulltests/test_binary_function.R
index 59cb2e6204405..442bed509bb1d 100644
--- a/R/pkg/inst/tests/testthat/test_binary_function.R
+++ b/R/pkg/tests/fulltests/test_binary_function.R
@@ -18,7 +18,7 @@
context("binary functions")
# JavaSparkContext handle
-sparkSession <- sparkR.session(enableHiveSupport = FALSE)
+sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE)
sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)
# Data
diff --git a/R/pkg/inst/tests/testthat/test_broadcast.R b/R/pkg/tests/fulltests/test_broadcast.R
similarity index 95%
rename from R/pkg/inst/tests/testthat/test_broadcast.R
rename to R/pkg/tests/fulltests/test_broadcast.R
index 65f204d096f43..5f74d4960451a 100644
--- a/R/pkg/inst/tests/testthat/test_broadcast.R
+++ b/R/pkg/tests/fulltests/test_broadcast.R
@@ -18,7 +18,7 @@
context("broadcast variables")
# JavaSparkContext handle
-sparkSession <- sparkR.session(enableHiveSupport = FALSE)
+sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE)
sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)
# Partitioned data
diff --git a/R/pkg/inst/tests/testthat/test_client.R b/R/pkg/tests/fulltests/test_client.R
similarity index 100%
rename from R/pkg/inst/tests/testthat/test_client.R
rename to R/pkg/tests/fulltests/test_client.R
diff --git a/R/pkg/inst/tests/testthat/test_context.R b/R/pkg/tests/fulltests/test_context.R
similarity index 94%
rename from R/pkg/inst/tests/testthat/test_context.R
rename to R/pkg/tests/fulltests/test_context.R
index c847113491113..73b0f5355518d 100644
--- a/R/pkg/inst/tests/testthat/test_context.R
+++ b/R/pkg/tests/fulltests/test_context.R
@@ -56,7 +56,7 @@ test_that("Check masked functions", {
test_that("repeatedly starting and stopping SparkR", {
for (i in 1:4) {
- sc <- suppressWarnings(sparkR.init())
+ sc <- suppressWarnings(sparkR.init(master = sparkRTestMaster))
rdd <- parallelize(sc, 1:20, 2L)
expect_equal(countRDD(rdd), 20)
suppressWarnings(sparkR.stop())
@@ -65,7 +65,7 @@ test_that("repeatedly starting and stopping SparkR", {
test_that("repeatedly starting and stopping SparkSession", {
for (i in 1:4) {
- sparkR.session(enableHiveSupport = FALSE)
+ sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE)
df <- createDataFrame(data.frame(dummy = 1:i))
expect_equal(count(df), i)
sparkR.session.stop()
@@ -73,12 +73,12 @@ test_that("repeatedly starting and stopping SparkSession", {
})
test_that("rdd GC across sparkR.stop", {
- sc <- sparkR.sparkContext() # sc should get id 0
+ sc <- sparkR.sparkContext(master = sparkRTestMaster) # sc should get id 0
rdd1 <- parallelize(sc, 1:20, 2L) # rdd1 should get id 1
rdd2 <- parallelize(sc, 1:10, 2L) # rdd2 should get id 2
sparkR.session.stop()
- sc <- sparkR.sparkContext() # sc should get id 0 again
+ sc <- sparkR.sparkContext(master = sparkRTestMaster) # sc should get id 0 again
# GC rdd1 before creating rdd3 and rdd2 after
rm(rdd1)
@@ -96,7 +96,7 @@ test_that("rdd GC across sparkR.stop", {
})
test_that("job group functions can be called", {
- sc <- sparkR.sparkContext()
+ sc <- sparkR.sparkContext(master = sparkRTestMaster)
setJobGroup("groupId", "job description", TRUE)
cancelJobGroup("groupId")
clearJobGroup()
@@ -108,7 +108,7 @@ test_that("job group functions can be called", {
})
test_that("utility function can be called", {
- sparkR.sparkContext()
+ sparkR.sparkContext(master = sparkRTestMaster)
setLogLevel("ERROR")
sparkR.session.stop()
})
@@ -161,14 +161,14 @@ test_that("sparkJars sparkPackages as comma-separated strings", {
})
test_that("spark.lapply should perform simple transforms", {
- sparkR.sparkContext()
+ sparkR.sparkContext(master = sparkRTestMaster)
doubled <- spark.lapply(1:10, function(x) { 2 * x })
expect_equal(doubled, as.list(2 * 1:10))
sparkR.session.stop()
})
test_that("add and get file to be downloaded with Spark job on every node", {
- sparkR.sparkContext()
+ sparkR.sparkContext(master = sparkRTestMaster)
# Test add file.
path <- tempfile(pattern = "hello", fileext = ".txt")
filename <- basename(path)
diff --git a/R/pkg/inst/tests/testthat/test_includePackage.R b/R/pkg/tests/fulltests/test_includePackage.R
similarity index 95%
rename from R/pkg/inst/tests/testthat/test_includePackage.R
rename to R/pkg/tests/fulltests/test_includePackage.R
index 563ea298c2dd8..f4ea0d1b5cb27 100644
--- a/R/pkg/inst/tests/testthat/test_includePackage.R
+++ b/R/pkg/tests/fulltests/test_includePackage.R
@@ -18,7 +18,7 @@
context("include R packages")
# JavaSparkContext handle
-sparkSession <- sparkR.session(enableHiveSupport = FALSE)
+sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE)
sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)
# Partitioned data
diff --git a/R/pkg/inst/tests/testthat/test_jvm_api.R b/R/pkg/tests/fulltests/test_jvm_api.R
similarity index 93%
rename from R/pkg/inst/tests/testthat/test_jvm_api.R
rename to R/pkg/tests/fulltests/test_jvm_api.R
index 7348c893d0af3..8b3b4f73de170 100644
--- a/R/pkg/inst/tests/testthat/test_jvm_api.R
+++ b/R/pkg/tests/fulltests/test_jvm_api.R
@@ -17,7 +17,7 @@
context("JVM API")
-sparkSession <- sparkR.session(enableHiveSupport = FALSE)
+sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE)
test_that("Create and call methods on object", {
jarr <- sparkR.newJObject("java.util.ArrayList")
diff --git a/R/pkg/inst/tests/testthat/test_mllib_classification.R b/R/pkg/tests/fulltests/test_mllib_classification.R
similarity index 82%
rename from R/pkg/inst/tests/testthat/test_mllib_classification.R
rename to R/pkg/tests/fulltests/test_mllib_classification.R
index 459254d271a58..892d8a69ae693 100644
--- a/R/pkg/inst/tests/testthat/test_mllib_classification.R
+++ b/R/pkg/tests/fulltests/test_mllib_classification.R
@@ -20,7 +20,7 @@ library(testthat)
context("MLlib classification algorithms, except for tree-based algorithms")
# Tests for MLlib classification algorithms in SparkR
-sparkSession <- sparkR.session(enableHiveSupport = FALSE)
+sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE)
absoluteSparkPath <- function(x) {
sparkHome <- sparkR.conf("spark.home")
@@ -38,9 +38,8 @@ test_that("spark.svmLinear", {
expect_true(class(summary$coefficients[, 1]) == "numeric")
coefs <- summary$coefficients[, "Estimate"]
- expected_coefs <- c(-0.1563083, -0.460648, 0.2276626, 1.055085)
+ expected_coefs <- c(-0.06004978, -0.1563083, -0.460648, 0.2276626, 1.055085)
expect_true(all(abs(coefs - expected_coefs) < 0.1))
- expect_equal(summary$intercept, -0.06004978, tolerance = 1e-2)
# Test prediction with string label
prediction <- predict(model, training)
@@ -50,25 +49,26 @@ test_that("spark.svmLinear", {
expect_equal(sort(as.list(take(select(prediction, "prediction"), 10))[[1]]), expected)
# Test model save and load
- modelPath <- tempfile(pattern = "spark-svm-linear", fileext = ".tmp")
- write.ml(model, modelPath)
- expect_error(write.ml(model, modelPath))
- write.ml(model, modelPath, overwrite = TRUE)
- model2 <- read.ml(modelPath)
- coefs <- summary(model)$coefficients
- coefs2 <- summary(model2)$coefficients
- expect_equal(coefs, coefs2)
- unlink(modelPath)
+ if (windows_with_hadoop()) {
+ modelPath <- tempfile(pattern = "spark-svm-linear", fileext = ".tmp")
+ write.ml(model, modelPath)
+ expect_error(write.ml(model, modelPath))
+ write.ml(model, modelPath, overwrite = TRUE)
+ model2 <- read.ml(modelPath)
+ coefs <- summary(model)$coefficients
+ coefs2 <- summary(model2)$coefficients
+ expect_equal(coefs, coefs2)
+ unlink(modelPath)
+ }
# Test prediction with numeric label
label <- c(0.0, 0.0, 0.0, 1.0, 1.0)
feature <- c(1.1419053, 0.9194079, -0.9498666, -1.1069903, 0.2809776)
data <- as.data.frame(cbind(label, feature))
df <- createDataFrame(data)
- model <- spark.svmLinear(df, label ~ feature, regParam = 0.1)
+ model <- spark.svmLinear(df, label ~ feature, regParam = 0.1, maxIter = 5)
prediction <- collect(select(predict(model, df), "prediction"))
expect_equal(sort(prediction$prediction), c("0.0", "0.0", "0.0", "1.0", "1.0"))
-
})
test_that("spark.logit", {
@@ -128,15 +128,17 @@ test_that("spark.logit", {
expect_true(all(abs(setosaCoefs - setosaCoefs) < 0.1))
# Test model save and load
- modelPath <- tempfile(pattern = "spark-logit", fileext = ".tmp")
- write.ml(model, modelPath)
- expect_error(write.ml(model, modelPath))
- write.ml(model, modelPath, overwrite = TRUE)
- model2 <- read.ml(modelPath)
- coefs <- summary(model)$coefficients
- coefs2 <- summary(model2)$coefficients
- expect_equal(coefs, coefs2)
- unlink(modelPath)
+ if (windows_with_hadoop()) {
+ modelPath <- tempfile(pattern = "spark-logit", fileext = ".tmp")
+ write.ml(model, modelPath)
+ expect_error(write.ml(model, modelPath))
+ write.ml(model, modelPath, overwrite = TRUE)
+ model2 <- read.ml(modelPath)
+ coefs <- summary(model)$coefficients
+ coefs2 <- summary(model2)$coefficients
+ expect_equal(coefs, coefs2)
+ unlink(modelPath)
+ }
# R code to reproduce the result.
# nolint start
@@ -243,19 +245,21 @@ test_that("spark.mlp", {
expect_equal(head(mlpPredictions$prediction, 6), c("1.0", "0.0", "0.0", "0.0", "0.0", "0.0"))
# Test model save/load
- modelPath <- tempfile(pattern = "spark-mlp", fileext = ".tmp")
- write.ml(model, modelPath)
- expect_error(write.ml(model, modelPath))
- write.ml(model, modelPath, overwrite = TRUE)
- model2 <- read.ml(modelPath)
- summary2 <- summary(model2)
-
- expect_equal(summary2$numOfInputs, 4)
- expect_equal(summary2$numOfOutputs, 3)
- expect_equal(summary2$layers, c(4, 5, 4, 3))
- expect_equal(length(summary2$weights), 64)
-
- unlink(modelPath)
+ if (windows_with_hadoop()) {
+ modelPath <- tempfile(pattern = "spark-mlp", fileext = ".tmp")
+ write.ml(model, modelPath)
+ expect_error(write.ml(model, modelPath))
+ write.ml(model, modelPath, overwrite = TRUE)
+ model2 <- read.ml(modelPath)
+ summary2 <- summary(model2)
+
+ expect_equal(summary2$numOfInputs, 4)
+ expect_equal(summary2$numOfOutputs, 3)
+ expect_equal(summary2$layers, c(4, 5, 4, 3))
+ expect_equal(length(summary2$weights), 64)
+
+ unlink(modelPath)
+ }
# Test default parameter
model <- spark.mlp(df, label ~ features, layers = c(4, 5, 4, 3))
@@ -284,22 +288,11 @@ test_that("spark.mlp", {
c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "2.0", "2.0", "1.0", "0.0"))
# test initialWeights
- model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2, initialWeights =
+ model <- spark.mlp(df, label ~ features, layers = c(4, 3), initialWeights =
c(0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9))
mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction"))
expect_equal(head(mlpPredictions$prediction, 10),
- c("1.0", "1.0", "1.0", "1.0", "2.0", "1.0", "2.0", "2.0", "1.0", "0.0"))
-
- model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2, initialWeights =
- c(0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 5.0, 5.0, 5.0, 5.0, 9.0, 9.0, 9.0, 9.0, 9.0))
- mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction"))
- expect_equal(head(mlpPredictions$prediction, 10),
- c("1.0", "1.0", "1.0", "1.0", "2.0", "1.0", "2.0", "2.0", "1.0", "0.0"))
-
- model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2)
- mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction"))
- expect_equal(head(mlpPredictions$prediction, 10),
- c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "0.0", "2.0", "1.0", "0.0"))
+ c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "2.0", "2.0", "1.0", "0.0"))
# Test formula works well
df <- suppressWarnings(createDataFrame(iris))
@@ -310,8 +303,6 @@ test_that("spark.mlp", {
expect_equal(summary$numOfOutputs, 3)
expect_equal(summary$layers, c(4, 3))
expect_equal(length(summary$weights), 15)
- expect_equal(head(summary$weights, 5), list(-1.1957257, -5.2693685, 7.4489734, -6.3751413,
- -10.2376130), tolerance = 1e-6)
})
test_that("spark.naiveBayes", {
@@ -367,16 +358,18 @@ test_that("spark.naiveBayes", {
"Yes", "Yes", "No", "No"))
# Test model save/load
- modelPath <- tempfile(pattern = "spark-naiveBayes", fileext = ".tmp")
- write.ml(m, modelPath)
- expect_error(write.ml(m, modelPath))
- write.ml(m, modelPath, overwrite = TRUE)
- m2 <- read.ml(modelPath)
- s2 <- summary(m2)
- expect_equal(s$apriori, s2$apriori)
- expect_equal(s$tables, s2$tables)
-
- unlink(modelPath)
+ if (windows_with_hadoop()) {
+ modelPath <- tempfile(pattern = "spark-naiveBayes", fileext = ".tmp")
+ write.ml(m, modelPath)
+ expect_error(write.ml(m, modelPath))
+ write.ml(m, modelPath, overwrite = TRUE)
+ m2 <- read.ml(modelPath)
+ s2 <- summary(m2)
+ expect_equal(s$apriori, s2$apriori)
+ expect_equal(s$tables, s2$tables)
+
+ unlink(modelPath)
+ }
# Test e1071::naiveBayes
if (requireNamespace("e1071", quietly = TRUE)) {
diff --git a/R/pkg/inst/tests/testthat/test_mllib_clustering.R b/R/pkg/tests/fulltests/test_mllib_clustering.R
similarity index 79%
rename from R/pkg/inst/tests/testthat/test_mllib_clustering.R
rename to R/pkg/tests/fulltests/test_mllib_clustering.R
index 1661e987b730f..4110e13da4948 100644
--- a/R/pkg/inst/tests/testthat/test_mllib_clustering.R
+++ b/R/pkg/tests/fulltests/test_mllib_clustering.R
@@ -20,7 +20,7 @@ library(testthat)
context("MLlib clustering algorithms")
# Tests for MLlib clustering algorithms in SparkR
-sparkSession <- sparkR.session(enableHiveSupport = FALSE)
+sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE)
absoluteSparkPath <- function(x) {
sparkHome <- sparkR.conf("spark.home")
@@ -53,18 +53,20 @@ test_that("spark.bisectingKmeans", {
c(0, 1, 2, 3))
# Test model save/load
- modelPath <- tempfile(pattern = "spark-bisectingkmeans", fileext = ".tmp")
- write.ml(model, modelPath)
- expect_error(write.ml(model, modelPath))
- write.ml(model, modelPath, overwrite = TRUE)
- model2 <- read.ml(modelPath)
- summary2 <- summary(model2)
- expect_equal(sort(unlist(summary.model$size)), sort(unlist(summary2$size)))
- expect_equal(summary.model$coefficients, summary2$coefficients)
- expect_true(!summary.model$is.loaded)
- expect_true(summary2$is.loaded)
-
- unlink(modelPath)
+ if (windows_with_hadoop()) {
+ modelPath <- tempfile(pattern = "spark-bisectingkmeans", fileext = ".tmp")
+ write.ml(model, modelPath)
+ expect_error(write.ml(model, modelPath))
+ write.ml(model, modelPath, overwrite = TRUE)
+ model2 <- read.ml(modelPath)
+ summary2 <- summary(model2)
+ expect_equal(sort(unlist(summary.model$size)), sort(unlist(summary2$size)))
+ expect_equal(summary.model$coefficients, summary2$coefficients)
+ expect_true(!summary.model$is.loaded)
+ expect_true(summary2$is.loaded)
+
+ unlink(modelPath)
+ }
})
test_that("spark.gaussianMixture", {
@@ -125,18 +127,20 @@ test_that("spark.gaussianMixture", {
expect_equal(p$prediction, c(0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1))
# Test model save/load
- modelPath <- tempfile(pattern = "spark-gaussianMixture", fileext = ".tmp")
- write.ml(model, modelPath)
- expect_error(write.ml(model, modelPath))
- write.ml(model, modelPath, overwrite = TRUE)
- model2 <- read.ml(modelPath)
- stats2 <- summary(model2)
- expect_equal(stats$lambda, stats2$lambda)
- expect_equal(unlist(stats$mu), unlist(stats2$mu))
- expect_equal(unlist(stats$sigma), unlist(stats2$sigma))
- expect_equal(unlist(stats$loglik), unlist(stats2$loglik))
-
- unlink(modelPath)
+ if (windows_with_hadoop()) {
+ modelPath <- tempfile(pattern = "spark-gaussianMixture", fileext = ".tmp")
+ write.ml(model, modelPath)
+ expect_error(write.ml(model, modelPath))
+ write.ml(model, modelPath, overwrite = TRUE)
+ model2 <- read.ml(modelPath)
+ stats2 <- summary(model2)
+ expect_equal(stats$lambda, stats2$lambda)
+ expect_equal(unlist(stats$mu), unlist(stats2$mu))
+ expect_equal(unlist(stats$sigma), unlist(stats2$sigma))
+ expect_equal(unlist(stats$loglik), unlist(stats2$loglik))
+
+ unlink(modelPath)
+ }
})
test_that("spark.kmeans", {
@@ -171,18 +175,20 @@ test_that("spark.kmeans", {
expect_true(class(summary.model$coefficients[1, ]) == "numeric")
# Test model save/load
- modelPath <- tempfile(pattern = "spark-kmeans", fileext = ".tmp")
- write.ml(model, modelPath)
- expect_error(write.ml(model, modelPath))
- write.ml(model, modelPath, overwrite = TRUE)
- model2 <- read.ml(modelPath)
- summary2 <- summary(model2)
- expect_equal(sort(unlist(summary.model$size)), sort(unlist(summary2$size)))
- expect_equal(summary.model$coefficients, summary2$coefficients)
- expect_true(!summary.model$is.loaded)
- expect_true(summary2$is.loaded)
-
- unlink(modelPath)
+ if (windows_with_hadoop()) {
+ modelPath <- tempfile(pattern = "spark-kmeans", fileext = ".tmp")
+ write.ml(model, modelPath)
+ expect_error(write.ml(model, modelPath))
+ write.ml(model, modelPath, overwrite = TRUE)
+ model2 <- read.ml(modelPath)
+ summary2 <- summary(model2)
+ expect_equal(sort(unlist(summary.model$size)), sort(unlist(summary2$size)))
+ expect_equal(summary.model$coefficients, summary2$coefficients)
+ expect_true(!summary.model$is.loaded)
+ expect_true(summary2$is.loaded)
+
+ unlink(modelPath)
+ }
# Test Kmeans on dataset that is sensitive to seed value
col1 <- c(1, 2, 3, 4, 0, 1, 2, 3, 4, 0)
@@ -236,22 +242,24 @@ test_that("spark.lda with libsvm", {
expect_true(logPrior <= 0 & !is.na(logPrior))
# Test model save/load
- modelPath <- tempfile(pattern = "spark-lda", fileext = ".tmp")
- write.ml(model, modelPath)
- expect_error(write.ml(model, modelPath))
- write.ml(model, modelPath, overwrite = TRUE)
- model2 <- read.ml(modelPath)
- stats2 <- summary(model2)
-
- expect_true(stats2$isDistributed)
- expect_equal(logLikelihood, stats2$logLikelihood)
- expect_equal(logPerplexity, stats2$logPerplexity)
- expect_equal(vocabSize, stats2$vocabSize)
- expect_equal(vocabulary, stats2$vocabulary)
- expect_equal(trainingLogLikelihood, stats2$trainingLogLikelihood)
- expect_equal(logPrior, stats2$logPrior)
-
- unlink(modelPath)
+ if (windows_with_hadoop()) {
+ modelPath <- tempfile(pattern = "spark-lda", fileext = ".tmp")
+ write.ml(model, modelPath)
+ expect_error(write.ml(model, modelPath))
+ write.ml(model, modelPath, overwrite = TRUE)
+ model2 <- read.ml(modelPath)
+ stats2 <- summary(model2)
+
+ expect_true(stats2$isDistributed)
+ expect_equal(logLikelihood, stats2$logLikelihood)
+ expect_equal(logPerplexity, stats2$logPerplexity)
+ expect_equal(vocabSize, stats2$vocabSize)
+ expect_equal(vocabulary, stats2$vocabulary)
+ expect_equal(trainingLogLikelihood, stats2$trainingLogLikelihood)
+ expect_equal(logPrior, stats2$logPrior)
+
+ unlink(modelPath)
+ }
})
test_that("spark.lda with text input", {
diff --git a/R/pkg/inst/tests/testthat/test_mllib_fpm.R b/R/pkg/tests/fulltests/test_mllib_fpm.R
similarity index 85%
rename from R/pkg/inst/tests/testthat/test_mllib_fpm.R
rename to R/pkg/tests/fulltests/test_mllib_fpm.R
index c38f1133897dd..69dda52f0c279 100644
--- a/R/pkg/inst/tests/testthat/test_mllib_fpm.R
+++ b/R/pkg/tests/fulltests/test_mllib_fpm.R
@@ -20,7 +20,7 @@ library(testthat)
context("MLlib frequent pattern mining")
# Tests for MLlib frequent pattern mining algorithms in SparkR
-sparkSession <- sparkR.session(enableHiveSupport = FALSE)
+sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE)
test_that("spark.fpGrowth", {
data <- selectExpr(createDataFrame(data.frame(items = c(
@@ -62,15 +62,17 @@ test_that("spark.fpGrowth", {
expect_equivalent(expected_predictions, collect(predict(model, new_data)))
- modelPath <- tempfile(pattern = "spark-fpm", fileext = ".tmp")
- write.ml(model, modelPath, overwrite = TRUE)
- loaded_model <- read.ml(modelPath)
+ if (windows_with_hadoop()) {
+ modelPath <- tempfile(pattern = "spark-fpm", fileext = ".tmp")
+ write.ml(model, modelPath, overwrite = TRUE)
+ loaded_model <- read.ml(modelPath)
- expect_equivalent(
- itemsets,
- collect(spark.freqItemsets(loaded_model)))
+ expect_equivalent(
+ itemsets,
+ collect(spark.freqItemsets(loaded_model)))
- unlink(modelPath)
+ unlink(modelPath)
+ }
model_without_numpartitions <- spark.fpGrowth(data, minSupport = 0.3, minConfidence = 0.8)
expect_equal(
diff --git a/R/pkg/inst/tests/testthat/test_mllib_recommendation.R b/R/pkg/tests/fulltests/test_mllib_recommendation.R
similarity index 59%
rename from R/pkg/inst/tests/testthat/test_mllib_recommendation.R
rename to R/pkg/tests/fulltests/test_mllib_recommendation.R
index 6b1040db93050..4d919c9d746b0 100644
--- a/R/pkg/inst/tests/testthat/test_mllib_recommendation.R
+++ b/R/pkg/tests/fulltests/test_mllib_recommendation.R
@@ -20,7 +20,7 @@ library(testthat)
context("MLlib recommendation algorithms")
# Tests for MLlib recommendation algorithms in SparkR
-sparkSession <- sparkR.session(enableHiveSupport = FALSE)
+sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE)
test_that("spark.als", {
data <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0),
@@ -37,29 +37,31 @@ test_that("spark.als", {
tolerance = 1e-4)
# Test model save/load
- modelPath <- tempfile(pattern = "spark-als", fileext = ".tmp")
- write.ml(model, modelPath)
- expect_error(write.ml(model, modelPath))
- write.ml(model, modelPath, overwrite = TRUE)
- model2 <- read.ml(modelPath)
- stats2 <- summary(model2)
- expect_equal(stats2$rating, "score")
- userFactors <- collect(stats$userFactors)
- itemFactors <- collect(stats$itemFactors)
- userFactors2 <- collect(stats2$userFactors)
- itemFactors2 <- collect(stats2$itemFactors)
+ if (windows_with_hadoop()) {
+ modelPath <- tempfile(pattern = "spark-als", fileext = ".tmp")
+ write.ml(model, modelPath)
+ expect_error(write.ml(model, modelPath))
+ write.ml(model, modelPath, overwrite = TRUE)
+ model2 <- read.ml(modelPath)
+ stats2 <- summary(model2)
+ expect_equal(stats2$rating, "score")
+ userFactors <- collect(stats$userFactors)
+ itemFactors <- collect(stats$itemFactors)
+ userFactors2 <- collect(stats2$userFactors)
+ itemFactors2 <- collect(stats2$itemFactors)
- orderUser <- order(userFactors$id)
- orderUser2 <- order(userFactors2$id)
- expect_equal(userFactors$id[orderUser], userFactors2$id[orderUser2])
- expect_equal(userFactors$features[orderUser], userFactors2$features[orderUser2])
+ orderUser <- order(userFactors$id)
+ orderUser2 <- order(userFactors2$id)
+ expect_equal(userFactors$id[orderUser], userFactors2$id[orderUser2])
+ expect_equal(userFactors$features[orderUser], userFactors2$features[orderUser2])
- orderItem <- order(itemFactors$id)
- orderItem2 <- order(itemFactors2$id)
- expect_equal(itemFactors$id[orderItem], itemFactors2$id[orderItem2])
- expect_equal(itemFactors$features[orderItem], itemFactors2$features[orderItem2])
+ orderItem <- order(itemFactors$id)
+ orderItem2 <- order(itemFactors2$id)
+ expect_equal(itemFactors$id[orderItem], itemFactors2$id[orderItem2])
+ expect_equal(itemFactors$features[orderItem], itemFactors2$features[orderItem2])
- unlink(modelPath)
+ unlink(modelPath)
+ }
})
sparkR.session.stop()
diff --git a/R/pkg/inst/tests/testthat/test_mllib_regression.R b/R/pkg/tests/fulltests/test_mllib_regression.R
similarity index 95%
rename from R/pkg/inst/tests/testthat/test_mllib_regression.R
rename to R/pkg/tests/fulltests/test_mllib_regression.R
index 3e9ad77198073..82472c92b9965 100644
--- a/R/pkg/inst/tests/testthat/test_mllib_regression.R
+++ b/R/pkg/tests/fulltests/test_mllib_regression.R
@@ -20,7 +20,7 @@ library(testthat)
context("MLlib regression algorithms, except for tree-based algorithms")
# Tests for MLlib regression algorithms in SparkR
-sparkSession <- sparkR.session(enableHiveSupport = FALSE)
+sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE)
test_that("formula of spark.glm", {
training <- suppressWarnings(createDataFrame(iris))
@@ -389,14 +389,16 @@ test_that("spark.isoreg", {
expect_equal(predict_result$prediction, c(7.0, 7.0, 6.0, 5.5, 5.0, 4.0, 1.0))
# Test model save/load
- modelPath <- tempfile(pattern = "spark-isoreg", fileext = ".tmp")
- write.ml(model, modelPath)
- expect_error(write.ml(model, modelPath))
- write.ml(model, modelPath, overwrite = TRUE)
- model2 <- read.ml(modelPath)
- expect_equal(result, summary(model2))
-
- unlink(modelPath)
+ if (windows_with_hadoop()) {
+ modelPath <- tempfile(pattern = "spark-isoreg", fileext = ".tmp")
+ write.ml(model, modelPath)
+ expect_error(write.ml(model, modelPath))
+ write.ml(model, modelPath, overwrite = TRUE)
+ model2 <- read.ml(modelPath)
+ expect_equal(result, summary(model2))
+
+ unlink(modelPath)
+ }
})
test_that("spark.survreg", {
@@ -438,17 +440,19 @@ test_that("spark.survreg", {
2.390146, 2.891269, 2.891269), tolerance = 1e-4)
# Test model save/load
- modelPath <- tempfile(pattern = "spark-survreg", fileext = ".tmp")
- write.ml(model, modelPath)
- expect_error(write.ml(model, modelPath))
- write.ml(model, modelPath, overwrite = TRUE)
- model2 <- read.ml(modelPath)
- stats2 <- summary(model2)
- coefs2 <- as.vector(stats2$coefficients[, 1])
- expect_equal(coefs, coefs2)
- expect_equal(rownames(stats$coefficients), rownames(stats2$coefficients))
-
- unlink(modelPath)
+ if (windows_with_hadoop()) {
+ modelPath <- tempfile(pattern = "spark-survreg", fileext = ".tmp")
+ write.ml(model, modelPath)
+ expect_error(write.ml(model, modelPath))
+ write.ml(model, modelPath, overwrite = TRUE)
+ model2 <- read.ml(modelPath)
+ stats2 <- summary(model2)
+ coefs2 <- as.vector(stats2$coefficients[, 1])
+ expect_equal(coefs, coefs2)
+ expect_equal(rownames(stats$coefficients), rownames(stats2$coefficients))
+
+ unlink(modelPath)
+ }
# Test survival::survreg
if (requireNamespace("survival", quietly = TRUE)) {
diff --git a/R/pkg/inst/tests/testthat/test_mllib_stat.R b/R/pkg/tests/fulltests/test_mllib_stat.R
similarity index 96%
rename from R/pkg/inst/tests/testthat/test_mllib_stat.R
rename to R/pkg/tests/fulltests/test_mllib_stat.R
index beb148e7702fd..1600833a5d03a 100644
--- a/R/pkg/inst/tests/testthat/test_mllib_stat.R
+++ b/R/pkg/tests/fulltests/test_mllib_stat.R
@@ -20,7 +20,7 @@ library(testthat)
context("MLlib statistics algorithms")
# Tests for MLlib statistics algorithms in SparkR
-sparkSession <- sparkR.session(enableHiveSupport = FALSE)
+sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE)
test_that("spark.kstest", {
data <- data.frame(test = c(0.1, 0.15, 0.2, 0.3, 0.25, -1, -0.5))
diff --git a/R/pkg/inst/tests/testthat/test_mllib_tree.R b/R/pkg/tests/fulltests/test_mllib_tree.R
similarity index 66%
rename from R/pkg/inst/tests/testthat/test_mllib_tree.R
rename to R/pkg/tests/fulltests/test_mllib_tree.R
index e0802a9b02d13..267aa80afdd26 100644
--- a/R/pkg/inst/tests/testthat/test_mllib_tree.R
+++ b/R/pkg/tests/fulltests/test_mllib_tree.R
@@ -20,7 +20,7 @@ library(testthat)
context("MLlib tree-based algorithms")
# Tests for MLlib tree-based algorithms in SparkR
-sparkSession <- sparkR.session(enableHiveSupport = FALSE)
+sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE)
absoluteSparkPath <- function(x) {
sparkHome <- sparkR.conf("spark.home")
@@ -44,21 +44,23 @@ test_that("spark.gbt", {
expect_equal(stats$numFeatures, 6)
expect_equal(length(stats$treeWeights), 20)
- modelPath <- tempfile(pattern = "spark-gbtRegression", fileext = ".tmp")
- write.ml(model, modelPath)
- expect_error(write.ml(model, modelPath))
- write.ml(model, modelPath, overwrite = TRUE)
- model2 <- read.ml(modelPath)
- stats2 <- summary(model2)
- expect_equal(stats$formula, stats2$formula)
- expect_equal(stats$numFeatures, stats2$numFeatures)
- expect_equal(stats$features, stats2$features)
- expect_equal(stats$featureImportances, stats2$featureImportances)
- expect_equal(stats$maxDepth, stats2$maxDepth)
- expect_equal(stats$numTrees, stats2$numTrees)
- expect_equal(stats$treeWeights, stats2$treeWeights)
-
- unlink(modelPath)
+ if (windows_with_hadoop()) {
+ modelPath <- tempfile(pattern = "spark-gbtRegression", fileext = ".tmp")
+ write.ml(model, modelPath)
+ expect_error(write.ml(model, modelPath))
+ write.ml(model, modelPath, overwrite = TRUE)
+ model2 <- read.ml(modelPath)
+ stats2 <- summary(model2)
+ expect_equal(stats$formula, stats2$formula)
+ expect_equal(stats$numFeatures, stats2$numFeatures)
+ expect_equal(stats$features, stats2$features)
+ expect_equal(stats$featureImportances, stats2$featureImportances)
+ expect_equal(stats$maxDepth, stats2$maxDepth)
+ expect_equal(stats$numTrees, stats2$numTrees)
+ expect_equal(stats$treeWeights, stats2$treeWeights)
+
+ unlink(modelPath)
+ }
# classification
# label must be binary - GBTClassifier currently only supports binary classification.
@@ -76,17 +78,19 @@ test_that("spark.gbt", {
expect_equal(length(grep("setosa", predictions)), 50)
expect_equal(length(grep("versicolor", predictions)), 50)
- modelPath <- tempfile(pattern = "spark-gbtClassification", fileext = ".tmp")
- write.ml(model, modelPath)
- expect_error(write.ml(model, modelPath))
- write.ml(model, modelPath, overwrite = TRUE)
- model2 <- read.ml(modelPath)
- stats2 <- summary(model2)
- expect_equal(stats$depth, stats2$depth)
- expect_equal(stats$numNodes, stats2$numNodes)
- expect_equal(stats$numClasses, stats2$numClasses)
-
- unlink(modelPath)
+ if (windows_with_hadoop()) {
+ modelPath <- tempfile(pattern = "spark-gbtClassification", fileext = ".tmp")
+ write.ml(model, modelPath)
+ expect_error(write.ml(model, modelPath))
+ write.ml(model, modelPath, overwrite = TRUE)
+ model2 <- read.ml(modelPath)
+ stats2 <- summary(model2)
+ expect_equal(stats$depth, stats2$depth)
+ expect_equal(stats$numNodes, stats2$numNodes)
+ expect_equal(stats$numClasses, stats2$numClasses)
+
+ unlink(modelPath)
+ }
iris2$NumericSpecies <- ifelse(iris2$Species == "setosa", 0, 1)
df <- suppressWarnings(createDataFrame(iris2))
@@ -99,10 +103,12 @@ test_that("spark.gbt", {
expect_equal(stats$maxDepth, 5)
# spark.gbt classification can work on libsvm data
- data <- read.df(absoluteSparkPath("data/mllib/sample_binary_classification_data.txt"),
- source = "libsvm")
- model <- spark.gbt(data, label ~ features, "classification")
- expect_equal(summary(model)$numFeatures, 692)
+ if (windows_with_hadoop()) {
+ data <- read.df(absoluteSparkPath("data/mllib/sample_binary_classification_data.txt"),
+ source = "libsvm")
+ model <- spark.gbt(data, label ~ features, "classification")
+ expect_equal(summary(model)$numFeatures, 692)
+ }
})
test_that("spark.randomForest", {
@@ -136,21 +142,23 @@ test_that("spark.randomForest", {
expect_equal(stats$numTrees, 20)
expect_equal(stats$maxDepth, 5)
- modelPath <- tempfile(pattern = "spark-randomForestRegression", fileext = ".tmp")
- write.ml(model, modelPath)
- expect_error(write.ml(model, modelPath))
- write.ml(model, modelPath, overwrite = TRUE)
- model2 <- read.ml(modelPath)
- stats2 <- summary(model2)
- expect_equal(stats$formula, stats2$formula)
- expect_equal(stats$numFeatures, stats2$numFeatures)
- expect_equal(stats$features, stats2$features)
- expect_equal(stats$featureImportances, stats2$featureImportances)
- expect_equal(stats$numTrees, stats2$numTrees)
- expect_equal(stats$maxDepth, stats2$maxDepth)
- expect_equal(stats$treeWeights, stats2$treeWeights)
-
- unlink(modelPath)
+ if (windows_with_hadoop()) {
+ modelPath <- tempfile(pattern = "spark-randomForestRegression", fileext = ".tmp")
+ write.ml(model, modelPath)
+ expect_error(write.ml(model, modelPath))
+ write.ml(model, modelPath, overwrite = TRUE)
+ model2 <- read.ml(modelPath)
+ stats2 <- summary(model2)
+ expect_equal(stats$formula, stats2$formula)
+ expect_equal(stats$numFeatures, stats2$numFeatures)
+ expect_equal(stats$features, stats2$features)
+ expect_equal(stats$featureImportances, stats2$featureImportances)
+ expect_equal(stats$numTrees, stats2$numTrees)
+ expect_equal(stats$maxDepth, stats2$maxDepth)
+ expect_equal(stats$treeWeights, stats2$treeWeights)
+
+ unlink(modelPath)
+ }
# classification
data <- suppressWarnings(createDataFrame(iris))
@@ -168,17 +176,19 @@ test_that("spark.randomForest", {
expect_equal(length(grep("setosa", predictions)), 50)
expect_equal(length(grep("versicolor", predictions)), 50)
- modelPath <- tempfile(pattern = "spark-randomForestClassification", fileext = ".tmp")
- write.ml(model, modelPath)
- expect_error(write.ml(model, modelPath))
- write.ml(model, modelPath, overwrite = TRUE)
- model2 <- read.ml(modelPath)
- stats2 <- summary(model2)
- expect_equal(stats$depth, stats2$depth)
- expect_equal(stats$numNodes, stats2$numNodes)
- expect_equal(stats$numClasses, stats2$numClasses)
-
- unlink(modelPath)
+ if (windows_with_hadoop()) {
+ modelPath <- tempfile(pattern = "spark-randomForestClassification", fileext = ".tmp")
+ write.ml(model, modelPath)
+ expect_error(write.ml(model, modelPath))
+ write.ml(model, modelPath, overwrite = TRUE)
+ model2 <- read.ml(modelPath)
+ stats2 <- summary(model2)
+ expect_equal(stats$depth, stats2$depth)
+ expect_equal(stats$numNodes, stats2$numNodes)
+ expect_equal(stats$numClasses, stats2$numClasses)
+
+ unlink(modelPath)
+ }
# Test numeric response variable
labelToIndex <- function(species) {
@@ -203,10 +213,12 @@ test_that("spark.randomForest", {
expect_equal(length(grep("2.0", predictions)), 50)
# spark.randomForest classification can work on libsvm data
- data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"),
- source = "libsvm")
- model <- spark.randomForest(data, label ~ features, "classification")
- expect_equal(summary(model)$numFeatures, 4)
+ if (windows_with_hadoop()) {
+ data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"),
+ source = "libsvm")
+ model <- spark.randomForest(data, label ~ features, "classification")
+ expect_equal(summary(model)$numFeatures, 4)
+ }
})
sparkR.session.stop()
diff --git a/R/pkg/inst/tests/testthat/test_parallelize_collect.R b/R/pkg/tests/fulltests/test_parallelize_collect.R
similarity index 98%
rename from R/pkg/inst/tests/testthat/test_parallelize_collect.R
rename to R/pkg/tests/fulltests/test_parallelize_collect.R
index 55972e1ba4693..3d122ccaf448f 100644
--- a/R/pkg/inst/tests/testthat/test_parallelize_collect.R
+++ b/R/pkg/tests/fulltests/test_parallelize_collect.R
@@ -33,7 +33,7 @@ numPairs <- list(list(1, 1), list(1, 2), list(2, 2), list(2, 3))
strPairs <- list(list(strList, strList), list(strList, strList))
# JavaSparkContext handle
-sparkSession <- sparkR.session(enableHiveSupport = FALSE)
+sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE)
jsc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)
# Tests
diff --git a/R/pkg/inst/tests/testthat/test_rdd.R b/R/pkg/tests/fulltests/test_rdd.R
similarity index 99%
rename from R/pkg/inst/tests/testthat/test_rdd.R
rename to R/pkg/tests/fulltests/test_rdd.R
index b72c801dd958d..6ee1fceffd822 100644
--- a/R/pkg/inst/tests/testthat/test_rdd.R
+++ b/R/pkg/tests/fulltests/test_rdd.R
@@ -18,7 +18,7 @@
context("basic RDD functions")
# JavaSparkContext handle
-sparkSession <- sparkR.session(enableHiveSupport = FALSE)
+sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE)
sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)
# Data
@@ -40,8 +40,8 @@ test_that("first on RDD", {
})
test_that("count and length on RDD", {
- expect_equal(countRDD(rdd), 10)
- expect_equal(lengthRDD(rdd), 10)
+ expect_equal(countRDD(rdd), 10)
+ expect_equal(lengthRDD(rdd), 10)
})
test_that("count by values and keys", {
diff --git a/R/pkg/inst/tests/testthat/test_shuffle.R b/R/pkg/tests/fulltests/test_shuffle.R
similarity index 98%
rename from R/pkg/inst/tests/testthat/test_shuffle.R
rename to R/pkg/tests/fulltests/test_shuffle.R
index d38efab0fd1df..98300c67c415f 100644
--- a/R/pkg/inst/tests/testthat/test_shuffle.R
+++ b/R/pkg/tests/fulltests/test_shuffle.R
@@ -18,7 +18,7 @@
context("partitionBy, groupByKey, reduceByKey etc.")
# JavaSparkContext handle
-sparkSession <- sparkR.session(enableHiveSupport = FALSE)
+sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE)
sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)
# Data
diff --git a/R/pkg/inst/tests/testthat/test_sparkR.R b/R/pkg/tests/fulltests/test_sparkR.R
similarity index 100%
rename from R/pkg/inst/tests/testthat/test_sparkR.R
rename to R/pkg/tests/fulltests/test_sparkR.R
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R
similarity index 92%
rename from R/pkg/inst/tests/testthat/test_sparkSQL.R
rename to R/pkg/tests/fulltests/test_sparkSQL.R
index 6a6c9a809ab13..f774554e5b2b1 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/tests/fulltests/test_sparkSQL.R
@@ -61,7 +61,11 @@ unsetHiveContext <- function() {
# Tests for SparkSQL functions in SparkR
filesBefore <- list.files(path = sparkRDir, all.files = TRUE)
-sparkSession <- sparkR.session()
+sparkSession <- if (windows_with_hadoop()) {
+ sparkR.session(master = sparkRTestMaster)
+ } else {
+ sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE)
+ }
sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)
mockLines <- c("{\"name\":\"Michael\"}",
@@ -96,6 +100,10 @@ mockLinesMapType <- c("{\"name\":\"Bob\",\"info\":{\"age\":16,\"height\":176.5}}
mapTypeJsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp")
writeLines(mockLinesMapType, mapTypeJsonPath)
+if (is_windows()) {
+ Sys.setenv(TZ = "GMT")
+}
+
test_that("calling sparkRSQL.init returns existing SQL context", {
sqlContext <- suppressWarnings(sparkRSQL.init(sc))
expect_equal(suppressWarnings(sparkRSQL.init(sc)), sqlContext)
@@ -303,51 +311,53 @@ test_that("createDataFrame uses files for large objects", {
})
test_that("read/write csv as DataFrame", {
- csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv")
- mockLinesCsv <- c("year,make,model,comment,blank",
- "\"2012\",\"Tesla\",\"S\",\"No comment\",",
- "1997,Ford,E350,\"Go get one now they are going fast\",",
- "2015,Chevy,Volt",
- "NA,Dummy,Placeholder")
- writeLines(mockLinesCsv, csvPath)
-
- # default "header" is false, inferSchema to handle "year" as "int"
- df <- read.df(csvPath, "csv", header = "true", inferSchema = "true")
- expect_equal(count(df), 4)
- expect_equal(columns(df), c("year", "make", "model", "comment", "blank"))
- expect_equal(sort(unlist(collect(where(df, df$year == 2015)))),
- sort(unlist(list(year = 2015, make = "Chevy", model = "Volt"))))
-
- # since "year" is "int", let's skip the NA values
- withoutna <- na.omit(df, how = "any", cols = "year")
- expect_equal(count(withoutna), 3)
-
- unlink(csvPath)
- csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv")
- mockLinesCsv <- c("year,make,model,comment,blank",
- "\"2012\",\"Tesla\",\"S\",\"No comment\",",
- "1997,Ford,E350,\"Go get one now they are going fast\",",
- "2015,Chevy,Volt",
- "Empty,Dummy,Placeholder")
- writeLines(mockLinesCsv, csvPath)
-
- df2 <- read.df(csvPath, "csv", header = "true", inferSchema = "true", na.strings = "Empty")
- expect_equal(count(df2), 4)
- withoutna2 <- na.omit(df2, how = "any", cols = "year")
- expect_equal(count(withoutna2), 3)
- expect_equal(count(where(withoutna2, withoutna2$make == "Dummy")), 0)
-
- # writing csv file
- csvPath2 <- tempfile(pattern = "csvtest2", fileext = ".csv")
- write.df(df2, path = csvPath2, "csv", header = "true")
- df3 <- read.df(csvPath2, "csv", header = "true")
- expect_equal(nrow(df3), nrow(df2))
- expect_equal(colnames(df3), colnames(df2))
- csv <- read.csv(file = list.files(csvPath2, pattern = "^part", full.names = T)[[1]])
- expect_equal(colnames(df3), colnames(csv))
-
- unlink(csvPath)
- unlink(csvPath2)
+ if (windows_with_hadoop()) {
+ csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv")
+ mockLinesCsv <- c("year,make,model,comment,blank",
+ "\"2012\",\"Tesla\",\"S\",\"No comment\",",
+ "1997,Ford,E350,\"Go get one now they are going fast\",",
+ "2015,Chevy,Volt",
+ "NA,Dummy,Placeholder")
+ writeLines(mockLinesCsv, csvPath)
+
+ # default "header" is false, inferSchema to handle "year" as "int"
+ df <- read.df(csvPath, "csv", header = "true", inferSchema = "true")
+ expect_equal(count(df), 4)
+ expect_equal(columns(df), c("year", "make", "model", "comment", "blank"))
+ expect_equal(sort(unlist(collect(where(df, df$year == 2015)))),
+ sort(unlist(list(year = 2015, make = "Chevy", model = "Volt"))))
+
+ # since "year" is "int", let's skip the NA values
+ withoutna <- na.omit(df, how = "any", cols = "year")
+ expect_equal(count(withoutna), 3)
+
+ unlink(csvPath)
+ csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv")
+ mockLinesCsv <- c("year,make,model,comment,blank",
+ "\"2012\",\"Tesla\",\"S\",\"No comment\",",
+ "1997,Ford,E350,\"Go get one now they are going fast\",",
+ "2015,Chevy,Volt",
+ "Empty,Dummy,Placeholder")
+ writeLines(mockLinesCsv, csvPath)
+
+ df2 <- read.df(csvPath, "csv", header = "true", inferSchema = "true", na.strings = "Empty")
+ expect_equal(count(df2), 4)
+ withoutna2 <- na.omit(df2, how = "any", cols = "year")
+ expect_equal(count(withoutna2), 3)
+ expect_equal(count(where(withoutna2, withoutna2$make == "Dummy")), 0)
+
+ # writing csv file
+ csvPath2 <- tempfile(pattern = "csvtest2", fileext = ".csv")
+ write.df(df2, path = csvPath2, "csv", header = "true")
+ df3 <- read.df(csvPath2, "csv", header = "true")
+ expect_equal(nrow(df3), nrow(df2))
+ expect_equal(colnames(df3), colnames(df2))
+ csv <- read.csv(file = list.files(csvPath2, pattern = "^part", full.names = T)[[1]])
+ expect_equal(colnames(df3), colnames(csv))
+
+ unlink(csvPath)
+ unlink(csvPath2)
+ }
})
test_that("Support other types for options", {
@@ -473,6 +483,12 @@ test_that("create DataFrame with different data types", {
expect_equal(collect(df), data.frame(l, stringsAsFactors = FALSE))
})
+test_that("SPARK-17902: collect() with stringsAsFactors enabled", {
+ df <- suppressWarnings(collect(createDataFrame(iris), stringsAsFactors = TRUE))
+ expect_equal(class(iris$Species), class(df$Species))
+ expect_equal(iris$Species, df$Species)
+})
+
test_that("SPARK-17811: can create DataFrame containing NA as date and time", {
df <- data.frame(
id = 1:2,
@@ -570,48 +586,50 @@ test_that("Collect DataFrame with complex types", {
})
test_that("read/write json files", {
- # Test read.df
- df <- read.df(jsonPath, "json")
- expect_is(df, "SparkDataFrame")
- expect_equal(count(df), 3)
-
- # Test read.df with a user defined schema
- schema <- structType(structField("name", type = "string"),
- structField("age", type = "double"))
-
- df1 <- read.df(jsonPath, "json", schema)
- expect_is(df1, "SparkDataFrame")
- expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double")))
-
- # Test loadDF
- df2 <- loadDF(jsonPath, "json", schema)
- expect_is(df2, "SparkDataFrame")
- expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double")))
-
- # Test read.json
- df <- read.json(jsonPath)
- expect_is(df, "SparkDataFrame")
- expect_equal(count(df), 3)
-
- # Test write.df
- jsonPath2 <- tempfile(pattern = "jsonPath2", fileext = ".json")
- write.df(df, jsonPath2, "json", mode = "overwrite")
-
- # Test write.json
- jsonPath3 <- tempfile(pattern = "jsonPath3", fileext = ".json")
- write.json(df, jsonPath3)
-
- # Test read.json()/jsonFile() works with multiple input paths
- jsonDF1 <- read.json(c(jsonPath2, jsonPath3))
- expect_is(jsonDF1, "SparkDataFrame")
- expect_equal(count(jsonDF1), 6)
- # Suppress warnings because jsonFile is deprecated
- jsonDF2 <- suppressWarnings(jsonFile(c(jsonPath2, jsonPath3)))
- expect_is(jsonDF2, "SparkDataFrame")
- expect_equal(count(jsonDF2), 6)
-
- unlink(jsonPath2)
- unlink(jsonPath3)
+ if (windows_with_hadoop()) {
+ # Test read.df
+ df <- read.df(jsonPath, "json")
+ expect_is(df, "SparkDataFrame")
+ expect_equal(count(df), 3)
+
+ # Test read.df with a user defined schema
+ schema <- structType(structField("name", type = "string"),
+ structField("age", type = "double"))
+
+ df1 <- read.df(jsonPath, "json", schema)
+ expect_is(df1, "SparkDataFrame")
+ expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double")))
+
+ # Test loadDF
+ df2 <- loadDF(jsonPath, "json", schema)
+ expect_is(df2, "SparkDataFrame")
+ expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double")))
+
+ # Test read.json
+ df <- read.json(jsonPath)
+ expect_is(df, "SparkDataFrame")
+ expect_equal(count(df), 3)
+
+ # Test write.df
+ jsonPath2 <- tempfile(pattern = "jsonPath2", fileext = ".json")
+ write.df(df, jsonPath2, "json", mode = "overwrite")
+
+ # Test write.json
+ jsonPath3 <- tempfile(pattern = "jsonPath3", fileext = ".json")
+ write.json(df, jsonPath3)
+
+ # Test read.json()/jsonFile() works with multiple input paths
+ jsonDF1 <- read.json(c(jsonPath2, jsonPath3))
+ expect_is(jsonDF1, "SparkDataFrame")
+ expect_equal(count(jsonDF1), 6)
+ # Suppress warnings because jsonFile is deprecated
+ jsonDF2 <- suppressWarnings(jsonFile(c(jsonPath2, jsonPath3)))
+ expect_is(jsonDF2, "SparkDataFrame")
+ expect_equal(count(jsonDF2), 6)
+
+ unlink(jsonPath2)
+ unlink(jsonPath3)
+ }
})
test_that("read/write json files - compression option", {
@@ -642,24 +660,27 @@ test_that("jsonRDD() on a RDD with json string", {
})
test_that("test tableNames and tables", {
+ count <- count(listTables())
+
df <- read.json(jsonPath)
createOrReplaceTempView(df, "table1")
- expect_equal(length(tableNames()), 1)
- expect_equal(length(tableNames("default")), 1)
+ expect_equal(length(tableNames()), count + 1)
+ expect_equal(length(tableNames("default")), count + 1)
+
tables <- listTables()
- expect_equal(count(tables), 1)
+ expect_equal(count(tables), count + 1)
expect_equal(count(tables()), count(tables))
expect_true("tableName" %in% colnames(tables()))
expect_true(all(c("tableName", "database", "isTemporary") %in% colnames(tables())))
suppressWarnings(registerTempTable(df, "table2"))
tables <- listTables()
- expect_equal(count(tables), 2)
+ expect_equal(count(tables), count + 2)
suppressWarnings(dropTempTable("table1"))
expect_true(dropTempView("table2"))
tables <- listTables()
- expect_equal(count(tables), 0)
+ expect_equal(count(tables), count + 0)
})
test_that(
@@ -692,37 +713,39 @@ test_that("test cache, uncache and clearCache", {
expect_true(dropTempView("table1"))
expect_error(uncacheTable("foo"),
- "Error in uncacheTable : no such table - Table or view 'foo' not found in database 'default'")
+ "Error in uncacheTable : analysis error - Table or view not found: foo")
})
test_that("insertInto() on a registered table", {
- df <- read.df(jsonPath, "json")
- write.df(df, parquetPath, "parquet", "overwrite")
- dfParquet <- read.df(parquetPath, "parquet")
-
- lines <- c("{\"name\":\"Bob\", \"age\":24}",
- "{\"name\":\"James\", \"age\":35}")
- jsonPath2 <- tempfile(pattern = "jsonPath2", fileext = ".tmp")
- parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet")
- writeLines(lines, jsonPath2)
- df2 <- read.df(jsonPath2, "json")
- write.df(df2, parquetPath2, "parquet", "overwrite")
- dfParquet2 <- read.df(parquetPath2, "parquet")
-
- createOrReplaceTempView(dfParquet, "table1")
- insertInto(dfParquet2, "table1")
- expect_equal(count(sql("select * from table1")), 5)
- expect_equal(first(sql("select * from table1 order by age"))$name, "Michael")
- expect_true(dropTempView("table1"))
-
- createOrReplaceTempView(dfParquet, "table1")
- insertInto(dfParquet2, "table1", overwrite = TRUE)
- expect_equal(count(sql("select * from table1")), 2)
- expect_equal(first(sql("select * from table1 order by age"))$name, "Bob")
- expect_true(dropTempView("table1"))
-
- unlink(jsonPath2)
- unlink(parquetPath2)
+ if (windows_with_hadoop()) {
+ df <- read.df(jsonPath, "json")
+ write.df(df, parquetPath, "parquet", "overwrite")
+ dfParquet <- read.df(parquetPath, "parquet")
+
+ lines <- c("{\"name\":\"Bob\", \"age\":24}",
+ "{\"name\":\"James\", \"age\":35}")
+ jsonPath2 <- tempfile(pattern = "jsonPath2", fileext = ".tmp")
+ parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet")
+ writeLines(lines, jsonPath2)
+ df2 <- read.df(jsonPath2, "json")
+ write.df(df2, parquetPath2, "parquet", "overwrite")
+ dfParquet2 <- read.df(parquetPath2, "parquet")
+
+ createOrReplaceTempView(dfParquet, "table1")
+ insertInto(dfParquet2, "table1")
+ expect_equal(count(sql("select * from table1")), 5)
+ expect_equal(first(sql("select * from table1 order by age"))$name, "Michael")
+ expect_true(dropTempView("table1"))
+
+ createOrReplaceTempView(dfParquet, "table1")
+ insertInto(dfParquet2, "table1", overwrite = TRUE)
+ expect_equal(count(sql("select * from table1")), 2)
+ expect_equal(first(sql("select * from table1 order by age"))$name, "Bob")
+ expect_true(dropTempView("table1"))
+
+ unlink(jsonPath2)
+ unlink(parquetPath2)
+ }
})
test_that("tableToDF() returns a new DataFrame", {
@@ -902,14 +925,16 @@ test_that("cache(), storageLevel(), persist(), and unpersist() on a DataFrame",
})
test_that("setCheckpointDir(), checkpoint() on a DataFrame", {
- checkpointDir <- file.path(tempdir(), "cproot")
- expect_true(length(list.files(path = checkpointDir, all.files = TRUE)) == 0)
-
- setCheckpointDir(checkpointDir)
- df <- read.json(jsonPath)
- df <- checkpoint(df)
- expect_is(df, "SparkDataFrame")
- expect_false(length(list.files(path = checkpointDir, all.files = TRUE)) == 0)
+ if (windows_with_hadoop()) {
+ checkpointDir <- file.path(tempdir(), "cproot")
+ expect_true(length(list.files(path = checkpointDir, all.files = TRUE)) == 0)
+
+ setCheckpointDir(checkpointDir)
+ df <- read.json(jsonPath)
+ df <- checkpoint(df)
+ expect_is(df, "SparkDataFrame")
+ expect_false(length(list.files(path = checkpointDir, all.files = TRUE)) == 0)
+ }
})
test_that("schema(), dtypes(), columns(), names() return the correct values/format", {
@@ -1267,45 +1292,47 @@ test_that("column calculation", {
})
test_that("test HiveContext", {
- setHiveContext(sc)
-
- schema <- structType(structField("name", "string"), structField("age", "integer"),
- structField("height", "float"))
- createTable("people", source = "json", schema = schema)
- df <- read.df(jsonPathNa, "json", schema)
- insertInto(df, "people")
- expect_equal(collect(sql("SELECT age from people WHERE name = 'Bob'"))$age, c(16))
- sql("DROP TABLE people")
-
- df <- createTable("json", jsonPath, "json")
- expect_is(df, "SparkDataFrame")
- expect_equal(count(df), 3)
- df2 <- sql("select * from json")
- expect_is(df2, "SparkDataFrame")
- expect_equal(count(df2), 3)
-
- jsonPath2 <- tempfile(pattern = "sparkr-test", fileext = ".tmp")
- saveAsTable(df, "json2", "json", "append", path = jsonPath2)
- df3 <- sql("select * from json2")
- expect_is(df3, "SparkDataFrame")
- expect_equal(count(df3), 3)
- unlink(jsonPath2)
-
- hivetestDataPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp")
- saveAsTable(df, "hivetestbl", path = hivetestDataPath)
- df4 <- sql("select * from hivetestbl")
- expect_is(df4, "SparkDataFrame")
- expect_equal(count(df4), 3)
- unlink(hivetestDataPath)
-
- parquetDataPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp")
- saveAsTable(df, "parquetest", "parquet", mode = "overwrite", path = parquetDataPath)
- df5 <- sql("select * from parquetest")
- expect_is(df5, "SparkDataFrame")
- expect_equal(count(df5), 3)
- unlink(parquetDataPath)
-
- unsetHiveContext()
+ if (windows_with_hadoop()) {
+ setHiveContext(sc)
+
+ schema <- structType(structField("name", "string"), structField("age", "integer"),
+ structField("height", "float"))
+ createTable("people", source = "json", schema = schema)
+ df <- read.df(jsonPathNa, "json", schema)
+ insertInto(df, "people")
+ expect_equal(collect(sql("SELECT age from people WHERE name = 'Bob'"))$age, c(16))
+ sql("DROP TABLE people")
+
+ df <- createTable("json", jsonPath, "json")
+ expect_is(df, "SparkDataFrame")
+ expect_equal(count(df), 3)
+ df2 <- sql("select * from json")
+ expect_is(df2, "SparkDataFrame")
+ expect_equal(count(df2), 3)
+
+ jsonPath2 <- tempfile(pattern = "sparkr-test", fileext = ".tmp")
+ saveAsTable(df, "json2", "json", "append", path = jsonPath2)
+ df3 <- sql("select * from json2")
+ expect_is(df3, "SparkDataFrame")
+ expect_equal(count(df3), 3)
+ unlink(jsonPath2)
+
+ hivetestDataPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp")
+ saveAsTable(df, "hivetestbl", path = hivetestDataPath)
+ df4 <- sql("select * from hivetestbl")
+ expect_is(df4, "SparkDataFrame")
+ expect_equal(count(df4), 3)
+ unlink(hivetestDataPath)
+
+ parquetDataPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp")
+ saveAsTable(df, "parquetest", "parquet", mode = "overwrite", path = parquetDataPath)
+ df5 <- sql("select * from parquetest")
+ expect_is(df5, "SparkDataFrame")
+ expect_equal(count(df5), 3)
+ unlink(parquetDataPath)
+
+ unsetHiveContext()
+ }
})
test_that("column operators", {
@@ -1890,6 +1917,18 @@ test_that("join(), crossJoin() and merge() on a DataFrame", {
unlink(jsonPath2)
unlink(jsonPath3)
+
+ # Join with broadcast hint
+ df1 <- sql("SELECT * FROM range(10e10)")
+ df2 <- sql("SELECT * FROM range(10e10)")
+
+ execution_plan <- capture.output(explain(join(df1, df2, df1$id == df2$id)))
+ expect_false(any(grepl("BroadcastHashJoin", execution_plan)))
+
+ execution_plan_hint <- capture.output(
+ explain(join(df1, hint(df2, "broadcast"), df1$id == df2$id))
+ )
+ expect_true(any(grepl("BroadcastHashJoin", execution_plan_hint)))
})
test_that("toJSON() on DataFrame", {
@@ -2085,34 +2124,36 @@ test_that("read/write ORC files - compression option", {
})
test_that("read/write Parquet files", {
- df <- read.df(jsonPath, "json")
- # Test write.df and read.df
- write.df(df, parquetPath, "parquet", mode = "overwrite")
- df2 <- read.df(parquetPath, "parquet")
- expect_is(df2, "SparkDataFrame")
- expect_equal(count(df2), 3)
-
- # Test write.parquet/saveAsParquetFile and read.parquet/parquetFile
- parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet")
- write.parquet(df, parquetPath2)
- parquetPath3 <- tempfile(pattern = "parquetPath3", fileext = ".parquet")
- suppressWarnings(saveAsParquetFile(df, parquetPath3))
- parquetDF <- read.parquet(c(parquetPath2, parquetPath3))
- expect_is(parquetDF, "SparkDataFrame")
- expect_equal(count(parquetDF), count(df) * 2)
- parquetDF2 <- suppressWarnings(parquetFile(parquetPath2, parquetPath3))
- expect_is(parquetDF2, "SparkDataFrame")
- expect_equal(count(parquetDF2), count(df) * 2)
-
- # Test if varargs works with variables
- saveMode <- "overwrite"
- mergeSchema <- "true"
- parquetPath4 <- tempfile(pattern = "parquetPath3", fileext = ".parquet")
- write.df(df, parquetPath3, "parquet", mode = saveMode, mergeSchema = mergeSchema)
-
- unlink(parquetPath2)
- unlink(parquetPath3)
- unlink(parquetPath4)
+ if (windows_with_hadoop()) {
+ df <- read.df(jsonPath, "json")
+ # Test write.df and read.df
+ write.df(df, parquetPath, "parquet", mode = "overwrite")
+ df2 <- read.df(parquetPath, "parquet")
+ expect_is(df2, "SparkDataFrame")
+ expect_equal(count(df2), 3)
+
+ # Test write.parquet/saveAsParquetFile and read.parquet/parquetFile
+ parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet")
+ write.parquet(df, parquetPath2)
+ parquetPath3 <- tempfile(pattern = "parquetPath3", fileext = ".parquet")
+ suppressWarnings(saveAsParquetFile(df, parquetPath3))
+ parquetDF <- read.parquet(c(parquetPath2, parquetPath3))
+ expect_is(parquetDF, "SparkDataFrame")
+ expect_equal(count(parquetDF), count(df) * 2)
+ parquetDF2 <- suppressWarnings(parquetFile(parquetPath2, parquetPath3))
+ expect_is(parquetDF2, "SparkDataFrame")
+ expect_equal(count(parquetDF2), count(df) * 2)
+
+ # Test if varargs works with variables
+ saveMode <- "overwrite"
+ mergeSchema <- "true"
+ parquetPath4 <- tempfile(pattern = "parquetPath3", fileext = ".parquet")
+ write.df(df, parquetPath3, "parquet", mode = saveMode, mergeSchema = mergeSchema)
+
+ unlink(parquetPath2)
+ unlink(parquetPath3)
+ unlink(parquetPath4)
+ }
})
test_that("read/write Parquet files - compression option/mode", {
@@ -2617,7 +2658,6 @@ test_that("dapply() and dapplyCollect() on a DataFrame", {
})
test_that("dapplyCollect() on DataFrame with a binary column", {
-
df <- data.frame(key = 1:3)
df$bytes <- lapply(df$key, serialize, connection = NULL)
@@ -2706,6 +2746,11 @@ test_that("gapply() and gapplyCollect() on a DataFrame", {
df1Collect <- gapplyCollect(df, list("a"), function(key, x) { x })
expect_identical(df1Collect, expected)
+ # gapply on empty grouping columns.
+ df1 <- gapply(df, c(), function(key, x) { x }, schema(df))
+ actual <- collect(df1)
+ expect_identical(actual, expected)
+
# Computes the sum of second column by grouping on the first and third columns
# and checks if the sum is larger than 2
schema <- structType(structField("a", "integer"), structField("e", "boolean"))
diff --git a/R/pkg/inst/tests/testthat/test_streaming.R b/R/pkg/tests/fulltests/test_streaming.R
similarity index 93%
rename from R/pkg/inst/tests/testthat/test_streaming.R
rename to R/pkg/tests/fulltests/test_streaming.R
index 03b1bd3dc1f44..d691de7cd725d 100644
--- a/R/pkg/inst/tests/testthat/test_streaming.R
+++ b/R/pkg/tests/fulltests/test_streaming.R
@@ -21,10 +21,10 @@ context("Structured Streaming")
# Tests for Structured Streaming functions in SparkR
-sparkSession <- sparkR.session(enableHiveSupport = FALSE)
+sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE)
jsonSubDir <- file.path("sparkr-test", "json", "")
-if (.Platform$OS.type == "windows") {
+if (is_windows()) {
# file.path removes the empty separator on Windows, adds it back
jsonSubDir <- paste0(jsonSubDir, .Platform$file.sep)
}
@@ -53,14 +53,17 @@ test_that("read.stream, write.stream, awaitTermination, stopQuery", {
q <- write.stream(counts, "memory", queryName = "people", outputMode = "complete")
expect_false(awaitTermination(q, 5 * 1000))
+ callJMethod(q@ssq, "processAllAvailable")
expect_equal(head(sql("SELECT count(*) FROM people"))[[1]], 3)
writeLines(mockLinesNa, jsonPathNa)
awaitTermination(q, 5 * 1000)
+ callJMethod(q@ssq, "processAllAvailable")
expect_equal(head(sql("SELECT count(*) FROM people"))[[1]], 6)
stopQuery(q)
expect_true(awaitTermination(q, 1))
+ expect_error(awaitTermination(q), NA)
})
test_that("print from explain, lastProgress, status, isActive", {
@@ -70,6 +73,7 @@ test_that("print from explain, lastProgress, status, isActive", {
q <- write.stream(counts, "memory", queryName = "people2", outputMode = "complete")
awaitTermination(q, 5 * 1000)
+ callJMethod(q@ssq, "processAllAvailable")
expect_equal(capture.output(explain(q))[[1]], "== Physical Plan ==")
expect_true(any(grepl("\"description\" : \"MemorySink\"", capture.output(lastProgress(q)))))
@@ -92,6 +96,7 @@ test_that("Stream other format", {
q <- write.stream(counts, "memory", queryName = "people3", outputMode = "complete")
expect_false(awaitTermination(q, 5 * 1000))
+ callJMethod(q@ssq, "processAllAvailable")
expect_equal(head(sql("SELECT count(*) FROM people3"))[[1]], 3)
expect_equal(queryName(q), "people3")
@@ -131,7 +136,7 @@ test_that("Terminated by error", {
expect_error(q <- write.stream(counts, "memory", queryName = "people4", outputMode = "complete"),
NA)
- expect_error(awaitTermination(q, 1),
+ expect_error(awaitTermination(q, 5 * 1000),
paste0(".*(awaitTermination : streaming query error - Invalid value '-1' for option",
" 'maxFilesPerTrigger', must be a positive integer).*"))
diff --git a/R/pkg/inst/tests/testthat/test_take.R b/R/pkg/tests/fulltests/test_take.R
similarity index 97%
rename from R/pkg/inst/tests/testthat/test_take.R
rename to R/pkg/tests/fulltests/test_take.R
index aaa532856c3d9..8936cc57da227 100644
--- a/R/pkg/inst/tests/testthat/test_take.R
+++ b/R/pkg/tests/fulltests/test_take.R
@@ -30,7 +30,7 @@ strList <- list("Dexter Morgan: Blood. Sometimes it sets my teeth on edge, ",
"raising me. But they're both dead now. I didn't kill them. Honest.")
# JavaSparkContext handle
-sparkSession <- sparkR.session(enableHiveSupport = FALSE)
+sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE)
sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)
test_that("take() gives back the original elements in correct count and order", {
diff --git a/R/pkg/inst/tests/testthat/test_textFile.R b/R/pkg/tests/fulltests/test_textFile.R
similarity index 98%
rename from R/pkg/inst/tests/testthat/test_textFile.R
rename to R/pkg/tests/fulltests/test_textFile.R
index 3b466066e9390..be2d2711ff88e 100644
--- a/R/pkg/inst/tests/testthat/test_textFile.R
+++ b/R/pkg/tests/fulltests/test_textFile.R
@@ -18,7 +18,7 @@
context("the textFile() function")
# JavaSparkContext handle
-sparkSession <- sparkR.session(enableHiveSupport = FALSE)
+sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE)
sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)
mockFile <- c("Spark is pretty.", "Spark is awesome.")
diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/tests/fulltests/test_utils.R
similarity index 88%
rename from R/pkg/inst/tests/testthat/test_utils.R
rename to R/pkg/tests/fulltests/test_utils.R
index 6d006eccf665e..50fc6f3ee9815 100644
--- a/R/pkg/inst/tests/testthat/test_utils.R
+++ b/R/pkg/tests/fulltests/test_utils.R
@@ -18,7 +18,7 @@
context("functions in utils.R")
# JavaSparkContext handle
-sparkSession <- sparkR.session(enableHiveSupport = FALSE)
+sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE)
sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)
test_that("convertJListToRList() gives back (deserializes) the original JLists
@@ -236,4 +236,29 @@ test_that("basenameSansExtFromUrl", {
expect_equal(basenameSansExtFromUrl(z), "spark-2.1.0--hive")
})
+test_that("getOne", {
+ dummy <- getOne(".dummyValue", envir = new.env(), ifnotfound = FALSE)
+ expect_equal(dummy, FALSE)
+})
+
+test_that("traverseParentDirs", {
+ if (is_windows()) {
+ # original path is included as-is, otherwise dirname() replaces \\ with / on windows
+ dirs <- traverseParentDirs("c:\\Users\\user\\AppData\\Local\\Apache\\Spark\\Cache\\spark2.2", 3)
+ expect <- c("c:\\Users\\user\\AppData\\Local\\Apache\\Spark\\Cache\\spark2.2",
+ "c:/Users/user/AppData/Local/Apache/Spark/Cache",
+ "c:/Users/user/AppData/Local/Apache/Spark",
+ "c:/Users/user/AppData/Local/Apache")
+ expect_equal(dirs, expect)
+ } else {
+ dirs <- traverseParentDirs("/Users/user/Library/Caches/spark/spark2.2", 1)
+ expect <- c("/Users/user/Library/Caches/spark/spark2.2", "/Users/user/Library/Caches/spark")
+ expect_equal(dirs, expect)
+
+ dirs <- traverseParentDirs("/home/u/.cache/spark/spark2.2", 1)
+ expect <- c("/home/u/.cache/spark/spark2.2", "/home/u/.cache/spark")
+ expect_equal(dirs, expect)
+ }
+})
+
sparkR.session.stop()
diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R
index 29812f872c784..1ff9ca90cafe9 100644
--- a/R/pkg/tests/run-all.R
+++ b/R/pkg/tests/run-all.R
@@ -21,14 +21,45 @@ library(SparkR)
# Turn all warnings into errors
options("warn" = 2)
+if (.Platform$OS.type == "windows") {
+ Sys.setenv(TZ = "GMT")
+}
+
# Setup global test environment
# Install Spark first to set SPARK_HOME
-install.spark()
+
+# NOTE(shivaram): We set overwrite to handle any old tar.gz files or directories left behind on
+# CRAN machines. For Jenkins we should already have SPARK_HOME set.
+install.spark(overwrite = TRUE)
sparkRDir <- file.path(Sys.getenv("SPARK_HOME"), "R")
-sparkRFilesBefore <- list.files(path = sparkRDir, all.files = TRUE)
sparkRWhitelistSQLDirs <- c("spark-warehouse", "metastore_db")
invisible(lapply(sparkRWhitelistSQLDirs,
function(x) { unlink(file.path(sparkRDir, x), recursive = TRUE, force = TRUE)}))
+sparkRFilesBefore <- list.files(path = sparkRDir, all.files = TRUE)
+
+sparkRTestMaster <- "local[1]"
+sparkRTestConfig <- list()
+if (identical(Sys.getenv("NOT_CRAN"), "true")) {
+ sparkRTestMaster <- ""
+} else {
+ # Disable hsperfdata on CRAN
+ old_java_opt <- Sys.getenv("_JAVA_OPTIONS")
+ Sys.setenv("_JAVA_OPTIONS" = paste("-XX:-UsePerfData", old_java_opt))
+ tmpDir <- tempdir()
+ tmpArg <- paste0("-Djava.io.tmpdir=", tmpDir)
+ sparkRTestConfig <- list(spark.driver.extraJavaOptions = tmpArg,
+ spark.executor.extraJavaOptions = tmpArg)
+}
test_package("SparkR")
+
+if (identical(Sys.getenv("NOT_CRAN"), "true")) {
+  # for testthat 1.0.2 or later, change reporter from "summary" to default_reporter()
+ testthat:::run_tests("SparkR",
+ file.path(sparkRDir, "pkg", "tests", "fulltests"),
+ NULL,
+ "summary")
+}
+
+SparkR:::uninstallDownloadedSpark()
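To tie the run-all.R changes together (an illustration, not part of the patch): each file under `fulltests/` starts its session from these globals, so a CRAN check run stays on a single-core local master with a private temp directory, while Jenkins (where `NOT_CRAN=true` leaves the master empty) keeps its own `spark.master`/`SPARK_HOME` defaults. A hedged sketch of such a preamble, with `sparkConfig` shown as an optional extra that the hunks above do not all pass:

```r
# Sketch of a fulltest preamble consuming the globals defined in run-all.R.
# sparkRTestMaster is "local[1]" on CRAN and "" (use existing defaults) on Jenkins;
# sparkRTestConfig points the driver and executor JVMs at a private java.io.tmpdir on CRAN.
sparkSession <- sparkR.session(master = sparkRTestMaster,
                               sparkConfig = sparkRTestConfig,
                               enableHiveSupport = FALSE)
```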
diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd
index a6ff650c33fea..ad660ec8b871b 100644
--- a/R/pkg/vignettes/sparkr-vignettes.Rmd
+++ b/R/pkg/vignettes/sparkr-vignettes.Rmd
@@ -27,6 +27,23 @@ vignette: >
limitations under the License.
-->
+```{r setup, include=FALSE}
+library(knitr)
+opts_hooks$set(eval = function(options) {
+ # override eval to FALSE only on windows
+ if (.Platform$OS.type == "windows") {
+ options$eval = FALSE
+ }
+ options
+})
+r_tmp_dir <- tempdir()
+tmp_arg <- paste0("-Djava.io.tmpdir=", r_tmp_dir)
+sparkSessionConfig <- list(spark.driver.extraJavaOptions = tmp_arg,
+ spark.executor.extraJavaOptions = tmp_arg)
+old_java_opt <- Sys.getenv("_JAVA_OPTIONS")
+Sys.setenv("_JAVA_OPTIONS" = paste("-XX:-UsePerfData", old_java_opt, sep = " "))
+```
+
## Overview
SparkR is an R package that provides a light-weight frontend to use Apache Spark from R. With Spark `r packageVersion("SparkR")`, SparkR provides a distributed data frame implementation that supports data processing operations like selection, filtering, aggregation etc. and distributed machine learning using [MLlib](http://spark.apache.org/mllib/).
@@ -46,8 +63,9 @@ We use default settings in which it runs in local mode. It auto downloads Spark
```{r, include=FALSE}
install.spark()
+sparkR.session(master = "local[1]", sparkConfig = sparkSessionConfig, enableHiveSupport = FALSE)
```
-```{r, message=FALSE, results="hide"}
+```{r, eval=FALSE}
sparkR.session()
```
@@ -65,7 +83,7 @@ We can view the first few rows of the `SparkDataFrame` by `head` or `showDF` fun
head(carsDF)
```
-Common data processing operations such as `filter`, `select` are supported on the `SparkDataFrame`.
+Common data processing operations such as `filter` and `select` are supported on the `SparkDataFrame`.
```{r}
carsSubDF <- select(carsDF, "model", "mpg", "hp")
carsSubDF <- filter(carsSubDF, carsSubDF$hp >= 200)
@@ -182,7 +200,7 @@ head(df)
```
### Data Sources
-SparkR supports operating on a variety of data sources through the `SparkDataFrame` interface. You can check the Spark SQL programming guide for more [specific options](https://spark.apache.org/docs/latest/sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources.
+SparkR supports operating on a variety of data sources through the `SparkDataFrame` interface. You can check the Spark SQL Programming Guide for more [specific options](https://spark.apache.org/docs/latest/sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources.
The general method for creating `SparkDataFrame` from data sources is `read.df`. This method takes in the path for the file to load and the type of data source, and the currently active Spark Session will be used automatically. SparkR supports reading CSV, JSON and Parquet files natively and through Spark Packages you can find data source connectors for popular file formats like Avro. These packages can be added with `sparkPackages` parameter when initializing SparkSession using `sparkR.session`.
@@ -232,7 +250,7 @@ write.df(people, path = "people.parquet", source = "parquet", mode = "overwrite"
```
### Hive Tables
-You can also create SparkDataFrames from Hive tables. To do this we will need to create a SparkSession with Hive support which can access tables in the Hive MetaStore. Note that Spark should have been built with Hive support and more details can be found in the [SQL programming guide](https://spark.apache.org/docs/latest/sql-programming-guide.html). In SparkR, by default it will attempt to create a SparkSession with Hive support enabled (`enableHiveSupport = TRUE`).
+You can also create SparkDataFrames from Hive tables. To do this we will need to create a SparkSession with Hive support which can access tables in the Hive MetaStore. Note that Spark should have been built with Hive support and more details can be found in the [SQL Programming Guide](https://spark.apache.org/docs/latest/sql-programming-guide.html). In SparkR, by default it will attempt to create a SparkSession with Hive support enabled (`enableHiveSupport = TRUE`).
```{r, eval=FALSE}
sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
@@ -364,7 +382,7 @@ out <- dapply(carsSubDF, function(x) { x <- cbind(x, x$mpg * 1.61) }, schema)
head(collect(out))
```
-Like `dapply`, apply a function to each partition of a `SparkDataFrame` and collect the result back. The output of function should be a `data.frame`, but no schema is required in this case. Note that `dapplyCollect` can fail if the output of UDF run on all the partition cannot be pulled to the driver and fit in driver memory.
+Like `dapply`, `dapplyCollect` can apply a function to each partition of a `SparkDataFrame` and collect the result back. The output of the function should be a `data.frame`, but no schema is required in this case. Note that `dapplyCollect` can fail if the output of the UDF on all partitions cannot be pulled into the driver's memory.
```{r}
out <- dapplyCollect(
@@ -390,7 +408,7 @@ result <- gapply(
head(arrange(result, "max_mpg", decreasing = TRUE))
```
-Like gapply, `gapplyCollect` applies a function to each partition of a `SparkDataFrame` and collect the result back to R `data.frame`. The output of the function should be a `data.frame` but no schema is required in this case. Note that `gapplyCollect` can fail if the output of UDF run on all the partition cannot be pulled to the driver and fit in driver memory.
+Like `gapply`, `gapplyCollect` can apply a function to each partition of a `SparkDataFrame` and collect the result back to R `data.frame`. The output of the function should be a `data.frame` but no schema is required in this case. Note that `gapplyCollect` can fail if the output of the UDF on all partitions cannot be pulled into the driver's memory.
```{r}
result <- gapplyCollect(
@@ -443,20 +461,20 @@ options(ops)
### SQL Queries
-A `SparkDataFrame` can also be registered as a temporary view in Spark SQL and that allows you to run SQL queries over its data. The sql function enables applications to run SQL queries programmatically and returns the result as a `SparkDataFrame`.
+A `SparkDataFrame` can also be registered as a temporary view in Spark SQL so that one can run SQL queries over its data. The sql function enables applications to run SQL queries programmatically and returns the result as a `SparkDataFrame`.
```{r}
people <- read.df(paste0(sparkR.conf("spark.home"),
"/examples/src/main/resources/people.json"), "json")
```
-Register this SparkDataFrame as a temporary view.
+Register this `SparkDataFrame` as a temporary view.
```{r}
createOrReplaceTempView(people, "people")
```
-SQL statements can be run by using the sql method.
+SQL statements can be run using the sql method.
```{r}
teenagers <- sql("SELECT name FROM people WHERE age >= 13 AND age <= 19")
head(teenagers)
@@ -505,6 +523,10 @@ SparkR supports the following machine learning models and algorithms.
* Alternating Least Squares (ALS)
+#### Frequent Pattern Mining
+
+* FP-growth
+
#### Statistics
* Kolmogorov-Smirnov Test
@@ -653,6 +675,7 @@ head(select(naiveBayesPrediction, "Class", "Sex", "Age", "Survived", "prediction
Survival analysis studies the expected duration of time until an event happens, and often the relationship with risk factors or treatment taken on the subject. In contrast to standard regression analysis, survival modeling has to deal with special characteristics in the data including non-negative survival time and censoring.
Accelerated Failure Time (AFT) model is a parametric survival model for censored data that assumes the effect of a covariate is to accelerate or decelerate the life course of an event by some constant. For more information, refer to the Wikipedia page [AFT Model](https://en.wikipedia.org/wiki/Accelerated_failure_time_model) and the references there. Different from a [Proportional Hazards Model](https://en.wikipedia.org/wiki/Proportional_hazards_model) designed for the same purpose, the AFT model is easier to parallelize because each instance contributes to the objective function independently.
+
```{r, warning=FALSE}
library(survival)
ovarianDF <- createDataFrame(ovarian)
@@ -707,7 +730,7 @@ summary(tweedieGLM1)
```
We can try other distributions in the tweedie family, for example, a compound Poisson distribution with a log link:
```{r}
-tweedieGLM2 <- spark.glm(carsDF, mpg ~ wt + hp, family = "tweedie",
+tweedieGLM2 <- spark.glm(carsDF, mpg ~ wt + hp, family = "tweedie",
var.power = 1.2, link.power = 0.0)
summary(tweedieGLM2)
```
@@ -760,7 +783,7 @@ head(predict(isoregModel, newDF))
`spark.gbt` fits a [gradient-boosted tree](https://en.wikipedia.org/wiki/Gradient_boosting) classification or regression model on a `SparkDataFrame`.
Users can call `summary` to get a summary of the fitted model, `predict` to make predictions, and `write.ml`/`read.ml` to save/load fitted models.
-Similar to the random forest example above, we use the `longley` dataset to train a gradient-boosted tree and make predictions:
+We use the `longley` dataset to train a gradient-boosted tree and make predictions:
```{r, warning=FALSE}
df <- createDataFrame(longley)
@@ -800,7 +823,7 @@ head(select(fitted, "Class", "prediction"))
`spark.gaussianMixture` fits multivariate [Gaussian Mixture Model](https://en.wikipedia.org/wiki/Mixture_model#Multivariate_Gaussian_mixture_model) (GMM) against a `SparkDataFrame`. [Expectation-Maximization](https://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm) (EM) is used to approximate the maximum likelihood estimator (MLE) of the model.
-We use a simulated example to demostrate the usage.
+We use a simulated example to demonstrate the usage.
```{r}
X1 <- data.frame(V1 = rnorm(4), V2 = rnorm(4))
X2 <- data.frame(V1 = rnorm(6, 3), V2 = rnorm(6, 4))
@@ -831,9 +854,9 @@ head(select(kmeansPredictions, "model", "mpg", "hp", "wt", "prediction"), n = 20
* Topics and documents both exist in a feature space, where feature vectors are vectors of word counts (bag of words).
-* Rather than estimating a clustering using a traditional distance, LDA uses a function based on a statistical model of how text documents are generated.
+* Rather than clustering using a traditional distance, LDA uses a function based on a statistical model of how text documents are generated.
-To use LDA, we need to specify a `features` column in `data` where each entry represents a document. There are two type options for the column:
+To use LDA, we need to specify a `features` column in `data` where each entry represents a document. There are two options for the column:
* character string: This can be a string of the whole document. It will be parsed automatically. Additional stop words can be added in `customizedStopWords`.
@@ -881,9 +904,9 @@ perplexity
`spark.als` learns latent factors in [collaborative filtering](https://en.wikipedia.org/wiki/Recommender_system#Collaborative_filtering) via [alternating least squares](http://dl.acm.org/citation.cfm?id=1608614).
-There are multiple options that can be configured in `spark.als`, including `rank`, `reg`, `nonnegative`. For a complete list, refer to the help file.
+There are multiple options that can be configured in `spark.als`, including `rank`, `reg`, and `nonnegative`. For a complete list, refer to the help file.
-```{r}
+```{r, eval=FALSE}
ratings <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0),
list(2, 1, 1.0), list(2, 2, 5.0))
df <- createDataFrame(ratings, c("user", "item", "rating"))
@@ -891,7 +914,7 @@ model <- spark.als(df, "rating", "user", "item", rank = 10, reg = 0.1, nonnegati
```
Extract latent factors.
-```{r}
+```{r, eval=FALSE}
stats <- summary(model)
userFactors <- stats$userFactors
itemFactors <- stats$itemFactors
@@ -901,11 +924,42 @@ head(itemFactors)
Make predictions.
-```{r}
+```{r, eval=FALSE}
predicted <- predict(model, df)
head(predicted)
```
+#### FP-growth
+
+`spark.fpGrowth` runs the FP-growth algorithm to mine frequent itemsets on a `SparkDataFrame`. The column given by `itemsCol` should contain arrays of values.
+
+```{r}
+df <- selectExpr(createDataFrame(data.frame(rawItems = c(
+ "T,R,U", "T,S", "V,R", "R,U,T,V", "R,S", "V,S,U", "U,R", "S,T", "V,R", "V,U,S",
+ "T,V,U", "R,V", "T,S", "T,S", "S,T", "S,U", "T,R", "V,R", "S,V", "T,S,U"
+))), "split(rawItems, ',') AS items")
+
+fpm <- spark.fpGrowth(df, minSupport = 0.2, minConfidence = 0.5)
+```
+
+The `spark.freqItemsets` method can be used to retrieve a `SparkDataFrame` with the frequent itemsets.
+
+```{r}
+head(spark.freqItemsets(fpm))
+```
+
+`spark.associationRules` returns a `SparkDataFrame` with the association rules.
+
+```{r}
+head(spark.associationRules(fpm))
+```
+
+We can make predictions based on the `antecedent`.
+
+```{r}
+head(predict(fpm, df))
+```
+
#### Kolmogorov-Smirnov Test
`spark.kstest` runs a two-sided, one-sample [Kolmogorov-Smirnov (KS) test](https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test).
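
As a minimal sketch, assuming a numeric column named `test` and the positional argument order (data, test column, null distribution, distribution parameters), the test can be run against a standard normal null hypothesis as follows:

```{r, eval=FALSE}
# Hypothetical data: a single numeric column named "test"
df <- createDataFrame(data.frame(test = rnorm(100)))

# Compare the empirical distribution of "test" against N(0, 1)
testSummary <- spark.kstest(df, "test", "norm", c(0, 1))
testSummary
```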
@@ -930,7 +984,7 @@ testSummary
### Model Persistence
-The following example shows how to save/load an ML model by SparkR.
+The following example shows how to save/load an ML model in SparkR.
```{r}
t <- as.data.frame(Titanic)
training <- createDataFrame(t)
@@ -952,6 +1006,72 @@ unlink(modelPath)
```
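
As a minimal sketch of the save/load pattern (an illustration, not necessarily the vignette's own code, assuming a Gaussian `spark.glm` fit on the `training` data above and a temporary path):

```{r, eval=FALSE}
# Fit any supported model, e.g. a Gaussian GLM on the Titanic-based training data
gaussianGLM <- spark.glm(training, Freq ~ Sex + Age, family = "gaussian")

# Persist the fitted model, load it back, and clean up
modelPath <- tempfile(pattern = "ml", fileext = ".tmp")
write.ml(gaussianGLM, modelPath)
reloaded <- read.ml(modelPath)
summary(reloaded)
unlink(modelPath)
```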
+## Structured Streaming
+
+SparkR supports the Structured Streaming API (experimental).
+
+You can check the Structured Streaming Programming Guide for [an introduction](https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html#programming-model) to its programming model and basic concepts.
+
+### Simple Source and Sink
+
+Spark has a few built-in input sources. As an example, the following reads text from a socket source, splits the lines into words, and displays the computed word counts:
+
+```{r, eval=FALSE}
+# Create DataFrame representing the stream of input lines from connection
+lines <- read.stream("socket", host = hostname, port = port)
+
+# Split the lines into words
+words <- selectExpr(lines, "explode(split(value, ' ')) as word")
+
+# Generate running word count
+wordCounts <- count(groupBy(words, "word"))
+
+# Start running the query that prints the running counts to the console
+query <- write.stream(wordCounts, "console", outputMode = "complete")
+```
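+
+`write.stream` returns a query handle. As a small hedged sketch (assuming the SparkR helpers `awaitTermination` and `stopQuery`), the query can be waited on and stopped when it is no longer needed:
+
+```{r, eval=FALSE}
+# Optionally block the current R session until the query stops
+awaitTermination(query)
+
+# Stop the streaming query started above
+stopQuery(query)
+```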
+
+### Kafka Source
+
+It is simple to read data from Kafka. For more information, see [Input Sources](https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html#input-sources) supported by Structured Streaming.
+
+```{r, eval=FALSE}
+topic <- read.stream("kafka",
+ kafka.bootstrap.servers = "host1:port1,host2:port2",
+ subscribe = "topic1")
+keyvalue <- selectExpr(topic, "CAST(key AS STRING)", "CAST(value AS STRING)")
+```
+
+### Operations and Sinks
+
+Most of the common operations on `SparkDataFrame` are supported for streaming, including selection, projection, and aggregation. Once you have defined the final result, you start the streaming computation by calling the `write.stream` method with a sink and an `outputMode`.
+
+For debugging, a streaming `SparkDataFrame` can be written to the console or to a temporary in-memory table; for further processing, it can be written in a fault-tolerant manner to a file sink in different formats.
+
+```{r, eval=FALSE}
+noAggDF <- select(where(deviceDataStreamingDf, "signal > 10"), "device")
+
+# Print new data to console
+write.stream(noAggDF, "console")
+
+# Write new data to Parquet files
+write.stream(noAggDF,
+ "parquet",
+ path = "path/to/destination/dir",
+ checkpointLocation = "path/to/checkpoint/dir")
+
+# Aggregate
+aggDF <- count(groupBy(noAggDF, "device"))
+
+# Print updated aggregations to console
+write.stream(aggDF, "console", outputMode = "complete")
+
+# Have all the aggregates in an in-memory table. The query name will be the table name
+write.stream(aggDF, "memory", queryName = "aggregates", outputMode = "complete")
+
+head(sql("select * from aggregates"))
+```
+
+
## Advanced Topics
### SparkR Object Classes
@@ -962,19 +1082,19 @@ There are three main object classes in SparkR you may be working with.
+ `sdf` stores a reference to the corresponding Spark Dataset in the Spark JVM backend.
+ `env` saves the meta-information of the object such as `isCached`.
-It can be created by data import methods or by transforming an existing `SparkDataFrame`. We can manipulate `SparkDataFrame` by numerous data processing functions and feed that into machine learning algorithms.
+ It can be created by data import methods or by transforming an existing `SparkDataFrame`. We can manipulate a `SparkDataFrame` with numerous data processing functions and feed the result into machine learning algorithms.
-* `Column`: an S4 class representing column of `SparkDataFrame`. The slot `jc` saves a reference to the corresponding Column object in the Spark JVM backend.
+* `Column`: an S4 class representing a column of a `SparkDataFrame`. The slot `jc` saves a reference to the corresponding `Column` object in the Spark JVM backend.
-It can be obtained from a `SparkDataFrame` by `$` operator, `df$col`. More often, it is used together with other functions, for example, with `select` to select particular columns, with `filter` and constructed conditions to select rows, with aggregation functions to compute aggregate statistics for each group.
+ It can be obtained from a `SparkDataFrame` with the `$` operator, e.g., `df$col`. More often, it is used together with other functions, for example, with `select` to select particular columns, with `filter` and constructed conditions to select rows, or with aggregation functions to compute aggregate statistics for each group (see the sketch after this list).
-* `GroupedData`: an S4 class representing grouped data created by `groupBy` or by transforming other `GroupedData`. Its `sgd` slot saves a reference to a RelationalGroupedDataset object in the backend.
+* `GroupedData`: an S4 class representing grouped data created by `groupBy` or by transforming other `GroupedData`. Its `sgd` slot saves a reference to a `RelationalGroupedDataset` object in the backend.
-This is often an intermediate object with group information and followed up by aggregation operations.
+ This is often an intermediate object that carries group information and is followed by aggregation operations.
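
As a minimal sketch tying the three classes together (using the built-in `mtcars` dataset purely for illustration):

```{r, eval=FALSE}
# SparkDataFrame: created by a data import method
df <- createDataFrame(mtcars)

# Column: obtained with the $ operator and combined with other functions
head(select(df, df$mpg))
head(filter(df, df$wt > 3))

# GroupedData: an intermediate object produced by groupBy, consumed by an aggregation
grouped <- groupBy(df, "cyl")
head(agg(grouped, avg(df$mpg)))
```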
### Architecture
-A complete description of architecture can be seen in reference, in particular the paper *SparkR: Scaling R Programs with Spark*.
+A complete description of the architecture can be found in the references, in particular the paper *SparkR: Scaling R Programs with Spark*.
Under the hood of SparkR is the Spark SQL engine. This avoids the overhead of running interpreted R code, and the optimized SQL execution engine in Spark uses structural information about the data and the computation flow to perform a number of optimizations that speed up the computation.
@@ -1028,3 +1148,7 @@ env | map
```{r, echo=FALSE}
sparkR.session.stop()
```
+
+```{r cleanup, include=FALSE}
+SparkR:::uninstallDownloadedSpark()
+```
diff --git a/R/run-tests.sh b/R/run-tests.sh
index 742a2c5ed76da..86bd8aad5f113 100755
--- a/R/run-tests.sh
+++ b/R/run-tests.sh
@@ -23,7 +23,7 @@ FAILED=0
LOGFILE=$FWDIR/unit-tests.out
rm -f $LOGFILE
-SPARK_TESTING=1 $FWDIR/../bin/spark-submit --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" --conf spark.hadoop.fs.defaultFS="file:///" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE
+SPARK_TESTING=1 NOT_CRAN=true $FWDIR/../bin/spark-submit --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" --conf spark.hadoop.fs.defaultFS="file:///" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE
FAILED=$((PIPESTATUS[0]||$FAILED))
NUM_TEST_WARNING="$(grep -c -e 'Warnings ----------------' $LOGFILE)"
@@ -38,6 +38,7 @@ FAILED=$((PIPESTATUS[0]||$FAILED))
NUM_CRAN_WARNING="$(grep -c WARNING$ $CRAN_CHECK_LOG_FILE)"
NUM_CRAN_ERROR="$(grep -c ERROR$ $CRAN_CHECK_LOG_FILE)"
NUM_CRAN_NOTES="$(grep -c NOTE$ $CRAN_CHECK_LOG_FILE)"
+HAS_PACKAGE_VERSION_WARN="$(grep -c "Insufficient package version" $CRAN_CHECK_LOG_FILE)"
if [[ $FAILED != 0 || $NUM_TEST_WARNING != 0 ]]; then
cat $LOGFILE
@@ -46,9 +47,10 @@ if [[ $FAILED != 0 || $NUM_TEST_WARNING != 0 ]]; then
echo -en "\033[0m" # No color
exit -1
else
- # We have 2 existing NOTEs for new maintainer, attach()
- # We have one more NOTE in Jenkins due to "No repository set"
- if [[ $NUM_CRAN_WARNING != 0 || $NUM_CRAN_ERROR != 0 || $NUM_CRAN_NOTES -gt 3 ]]; then
+ # We have 2 NOTEs: one for RoxygenNote and one, in Jenkins only, for "No repository set"
+ # For non-latest version branches, there is one WARNING for the package version
+ if [[ ($NUM_CRAN_WARNING != 0 || $NUM_CRAN_ERROR != 0 || $NUM_CRAN_NOTES -gt 2) &&
+ ($HAS_PACKAGE_VERSION_WARN != 1 || $NUM_CRAN_WARNING != 1 || $NUM_CRAN_ERROR != 0 || $NUM_CRAN_NOTES -gt 1) ]]; then
cat $CRAN_CHECK_LOG_FILE
echo -en "\033[31m" # Red
echo "Had CRAN check errors; see logs."
diff --git a/appveyor.yml b/appveyor.yml
index bbb27589cad09..c7660f115c538 100644
--- a/appveyor.yml
+++ b/appveyor.yml
@@ -26,10 +26,14 @@ branches:
only_commits:
files:
+ - appveyor.yml
+ - dev/appveyor-install-dependencies.ps1
- R/
- sql/core/src/main/scala/org/apache/spark/sql/api/r/
- core/src/main/scala/org/apache/spark/api/r/
- mllib/src/main/scala/org/apache/spark/ml/r/
+ - core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
+ - bin/*.cmd
cache:
- C:\Users\appveyor\.m2
@@ -38,16 +42,17 @@ install:
# Install maven and dependencies
- ps: .\dev\appveyor-install-dependencies.ps1
# Required package for R unit tests
- - cmd: R -e "install.packages('testthat', repos='http://cran.us.r-project.org')"
- - cmd: R -e "packageVersion('testthat')"
- - cmd: R -e "install.packages('e1071', repos='http://cran.us.r-project.org')"
- - cmd: R -e "packageVersion('e1071')"
- - cmd: R -e "install.packages('survival', repos='http://cran.us.r-project.org')"
- - cmd: R -e "packageVersion('survival')"
+ - cmd: R -e "install.packages(c('knitr', 'rmarkdown', 'devtools', 'e1071', 'survival'), repos='http://cran.us.r-project.org')"
+ # Here, we use the fixed version of testthat. For more details, please see SPARK-22817.
+ - cmd: R -e "devtools::install_version('testthat', version = '1.0.2', repos='http://cran.us.r-project.org')"
+ - cmd: R -e "packageVersion('knitr'); packageVersion('rmarkdown'); packageVersion('testthat'); packageVersion('e1071'); packageVersion('survival')"
build_script:
- cmd: mvn -DskipTests -Psparkr -Phive -Phive-thriftserver package
+environment:
+ NOT_CRAN: true
+
test_script:
- cmd: .\bin\spark-submit2.cmd --conf spark.hadoop.fs.defaultFS="file:///" R\pkg\tests\run-all.R
@@ -56,4 +61,3 @@ notifications:
on_build_success: false
on_build_failure: false
on_build_status_changed: false
-
diff --git a/assembly/pom.xml b/assembly/pom.xml
index 9d8607d9137c6..f9ec6e7617607 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.2.0-SNAPSHOT</version>
+ <version>2.2.3-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
diff --git a/bin/find-spark-home.cmd b/bin/find-spark-home.cmd
new file mode 100644
index 0000000000000..6025f67c38de4
--- /dev/null
+++ b/bin/find-spark-home.cmd
@@ -0,0 +1,60 @@
+@echo off
+
+rem
+rem Licensed to the Apache Software Foundation (ASF) under one or more
+rem contributor license agreements. See the NOTICE file distributed with
+rem this work for additional information regarding copyright ownership.
+rem The ASF licenses this file to You under the Apache License, Version 2.0
+rem (the "License"); you may not use this file except in compliance with
+rem the License. You may obtain a copy of the License at
+rem
+rem http://www.apache.org/licenses/LICENSE-2.0
+rem
+rem Unless required by applicable law or agreed to in writing, software
+rem distributed under the License is distributed on an "AS IS" BASIS,
+rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+rem See the License for the specific language governing permissions and
+rem limitations under the License.
+rem
+
+rem Path to Python script finding SPARK_HOME
+set FIND_SPARK_HOME_PYTHON_SCRIPT=%~dp0find_spark_home.py
+
+rem Default to standard python interpreter unless told otherwise
+set PYTHON_RUNNER=python
+rem If PYSPARK_DRIVER_PYTHON is set, it overrides the Python interpreter used
+if not "x%PYSPARK_DRIVER_PYTHON%"=="x" (
+ set PYTHON_RUNNER=%PYSPARK_DRIVER_PYTHON%
+)
+rem If PYSPARK_PYTHON is set, it overrides the Python interpreter used
+if not "x%PYSPARK_PYTHON%"=="x" (
+ set PYTHON_RUNNER=%PYSPARK_PYTHON%
+)
+
+rem If there is Python installed, try to use the root dir as SPARK_HOME
+where %PYTHON_RUNNER% > nul 2>&1
+if %ERRORLEVEL% neq 0 (
+ if not exist %PYTHON_RUNNER% (
+ if "x%SPARK_HOME%"=="x" (
+ echo Missing Python executable '%PYTHON_RUNNER%', defaulting to '%~dp0..' for SPARK_HOME ^
+environment variable. Please install Python or specify the correct Python executable in ^
+PYSPARK_DRIVER_PYTHON or PYSPARK_PYTHON environment variable to detect SPARK_HOME safely.
+ set SPARK_HOME=%~dp0..
+ )
+ )
+)
+
+rem Only attempt to find SPARK_HOME if it is not set.
+if "x%SPARK_HOME%"=="x" (
+ if not exist "%FIND_SPARK_HOME_PYTHON_SCRIPT%" (
+ rem If we are not in the same directory as find_spark_home.py we are not pip installed so we don't
+ rem need to search the different Python directories for a Spark installation.
+ rem Note only that, if the user has pip installed PySpark but is directly calling pyspark-shell or
+ rem spark-submit in another directory we want to use that version of PySpark rather than the
+ rem pip installed version of PySpark.
+ set SPARK_HOME=%~dp0..
+ ) else (
+ rem We are pip installed, use the Python script to resolve a reasonable SPARK_HOME
+ for /f "delims=" %%i in ('%PYTHON_RUNNER% %FIND_SPARK_HOME_PYTHON_SCRIPT%') do set SPARK_HOME=%%i
+ )
+)
diff --git a/bin/pyspark b/bin/pyspark
index 98387c2ec5b8a..95ab62880654f 100755
--- a/bin/pyspark
+++ b/bin/pyspark
@@ -25,14 +25,14 @@ source "${SPARK_HOME}"/bin/load-spark-env.sh
export _SPARK_CMD_USAGE="Usage: ./bin/pyspark [options]"
# In Spark 2.0, IPYTHON and IPYTHON_OPTS are removed and pyspark fails to launch if either option
-# is set in the user's environment. Instead, users should set PYSPARK_DRIVER_PYTHON=ipython
+# is set in the user's environment. Instead, users should set PYSPARK_DRIVER_PYTHON=ipython
# to use IPython and set PYSPARK_DRIVER_PYTHON_OPTS to pass options when starting the Python driver
# (e.g. PYSPARK_DRIVER_PYTHON_OPTS='notebook'). This supports full customization of the IPython
# and executor Python executables.
# Fail noisily if removed options are set
if [[ -n "$IPYTHON" || -n "$IPYTHON_OPTS" ]]; then
- echo "Error in pyspark startup:"
+ echo "Error in pyspark startup:"
echo "IPYTHON and IPYTHON_OPTS are removed in Spark 2.0+. Remove these from the environment and set PYSPARK_DRIVER_PYTHON and PYSPARK_DRIVER_PYTHON_OPTS instead."
exit 1
fi
@@ -57,7 +57,7 @@ export PYSPARK_PYTHON
# Add the PySpark classes to the Python path:
export PYTHONPATH="${SPARK_HOME}/python/:$PYTHONPATH"
-export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.4-src.zip:$PYTHONPATH"
+export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.7-src.zip:$PYTHONPATH"
# Load the PySpark shell.py script when ./pyspark is used interactively:
export OLD_PYTHONSTARTUP="$PYTHONSTARTUP"
diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd
index f211c0873ad2f..15fa910c277b3 100644
--- a/bin/pyspark2.cmd
+++ b/bin/pyspark2.cmd
@@ -18,7 +18,7 @@ rem limitations under the License.
rem
rem Figure out where the Spark framework is installed
-set SPARK_HOME=%~dp0..
+call "%~dp0find-spark-home.cmd"
call "%SPARK_HOME%\bin\load-spark-env.cmd"
set _SPARK_CMD_USAGE=Usage: bin\pyspark.cmd [options]
@@ -30,7 +30,7 @@ if "x%PYSPARK_DRIVER_PYTHON%"=="x" (
)
set PYTHONPATH=%SPARK_HOME%\python;%PYTHONPATH%
-set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.4-src.zip;%PYTHONPATH%
+set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.7-src.zip;%PYTHONPATH%
set OLD_PYTHONSTARTUP=%PYTHONSTARTUP%
set PYTHONSTARTUP=%SPARK_HOME%\python\pyspark\shell.py
diff --git a/bin/run-example.cmd b/bin/run-example.cmd
index f9b786e92b823..7cfaa7e996e89 100644
--- a/bin/run-example.cmd
+++ b/bin/run-example.cmd
@@ -17,6 +17,8 @@ rem See the License for the specific language governing permissions and
rem limitations under the License.
rem
-set SPARK_HOME=%~dp0..
+rem Figure out where the Spark framework is installed
+call "%~dp0find-spark-home.cmd"
+
set _SPARK_CMD_USAGE=Usage: ./bin/run-example [options] example-class [example args]
cmd /V /E /C "%~dp0spark-submit.cmd" run-example %*
diff --git a/bin/spark-class b/bin/spark-class
index 77ea40cc37946..65d3b9612909a 100755
--- a/bin/spark-class
+++ b/bin/spark-class
@@ -72,6 +72,8 @@ build_command() {
printf "%d\0" $?
}
+# Turn off posix mode since it does not allow process substitution
+set +o posix
CMD=()
while IFS= read -d '' -r ARG; do
CMD+=("$ARG")
diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd
index 9faa7d65f83e4..5da7d7a430d79 100644
--- a/bin/spark-class2.cmd
+++ b/bin/spark-class2.cmd
@@ -18,7 +18,7 @@ rem limitations under the License.
rem
rem Figure out where the Spark framework is installed
-set SPARK_HOME=%~dp0..
+call "%~dp0find-spark-home.cmd"
call "%SPARK_HOME%\bin\load-spark-env.cmd"
@@ -29,7 +29,7 @@ if "x%1"=="x" (
)
rem Find Spark jars.
-if exist "%SPARK_HOME%\RELEASE" (
+if exist "%SPARK_HOME%\jars" (
set SPARK_JARS_DIR="%SPARK_HOME%\jars"
) else (
set SPARK_JARS_DIR="%SPARK_HOME%\assembly\target\scala-%SPARK_SCALA_VERSION%\jars"
@@ -51,7 +51,7 @@ if not "x%SPARK_PREPEND_CLASSES%"=="x" (
rem Figure out where java is.
set RUNNER=java
if not "x%JAVA_HOME%"=="x" (
- set RUNNER="%JAVA_HOME%\bin\java"
+ set RUNNER=%JAVA_HOME%\bin\java
) else (
where /q "%RUNNER%"
if ERRORLEVEL 1 (
diff --git a/bin/spark-shell2.cmd b/bin/spark-shell2.cmd
index 7b5d396be888c..aaf71906c6526 100644
--- a/bin/spark-shell2.cmd
+++ b/bin/spark-shell2.cmd
@@ -17,7 +17,9 @@ rem See the License for the specific language governing permissions and
rem limitations under the License.
rem
-set SPARK_HOME=%~dp0..
+rem Figure out where the Spark framework is installed
+call "%~dp0find-spark-home.cmd"
+
set _SPARK_CMD_USAGE=Usage: .\bin\spark-shell.cmd [options]
rem SPARK-4161: scala does not assume use of the java classpath,
diff --git a/bin/sparkR2.cmd b/bin/sparkR2.cmd
index 459b780e2ae33..b48bea345c0b9 100644
--- a/bin/sparkR2.cmd
+++ b/bin/sparkR2.cmd
@@ -18,7 +18,7 @@ rem limitations under the License.
rem
rem Figure out where the Spark framework is installed
-set SPARK_HOME=%~dp0..
+call "%~dp0find-spark-home.cmd"
call "%SPARK_HOME%\bin\load-spark-env.cmd"
diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml
index 8657af744c069..55d29d5729e0a 100644
--- a/common/network-common/pom.xml
+++ b/common/network-common/pom.xml
@@ -22,7 +22,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.2.0-SNAPSHOT</version>
+ <version>2.2.3-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
@@ -90,7 +90,8 @@
org.apache.spark
spark-tags_${scala.binary.version}
-
+ test
+
1.7.7
hadoop2
0.9.3
@@ -173,7 +170,7 @@
3.5
3.2.10
- 3.0.0
+ 3.0.8
2.22.2
2.9.3
3.5.2
@@ -226,7 +223,7 @@
central
Maven Repository
- https://repo1.maven.org/maven2
+ https://repo.maven.apache.org/maven2
true
@@ -238,7 +235,7 @@
central
- https://repo1.maven.org/maven2
+ https://repo.maven.apache.org/maven2
true
@@ -513,7 +510,6 @@
org.xerial.snappy
snappy-java
${snappy.version}
- ${hadoop.deps.scope}
net.jpountz.lz4
@@ -659,7 +655,7 @@
org.scalanlp
breeze_${scala.binary.version}
- 0.12
+ 0.13.2
@@ -2010,7 +2006,7 @@
**/*Suite.java
${project.build.directory}/surefire-reports
- -Xmx3g -Xss4096k -XX:ReservedCodeCacheSize=${CodeCacheSize}
+ -ea -Xmx3g -Xss4m -XX:ReservedCodeCacheSize=${CodeCacheSize}
JavaConversions
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index 765c92b8d3b9e..7e4e4073ece57 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -22,7 +22,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.2.0-SNAPSHOT</version>
+ <version>2.2.3-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
index 1ecb3d1958f43..f40412f0a6ee9 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
@@ -371,7 +371,7 @@ querySpecification
(RECORDREADER recordReader=STRING)?
fromClause?
(WHERE where=booleanExpression)?)
- | ((kind=SELECT hint? setQuantifier? namedExpressionSeq fromClause?
+ | ((kind=SELECT (hints+=hint)* setQuantifier? namedExpressionSeq fromClause?
| fromClause (kind=SELECT setQuantifier? namedExpressionSeq)?)
lateralView*
(WHERE where=booleanExpression)?
@@ -381,12 +381,12 @@ querySpecification
;
hint
- : '/*+' hintStatement '*/'
+ : '/*+' hintStatements+=hintStatement (','? hintStatements+=hintStatement)* '*/'
;
hintStatement
: hintName=identifier
- | hintName=identifier '(' parameters+=identifier (',' parameters+=identifier)* ')'
+ | hintName=identifier '(' parameters+=primaryExpression (',' parameters+=primaryExpression)* ')'
;
fromClause
@@ -548,10 +548,10 @@ valueExpression
;
primaryExpression
- : name=(CURRENT_DATE | CURRENT_TIMESTAMP) #timeFunctionCall
- | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase
+ : CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase
| CASE value=expression whenClause+ (ELSE elseExpression=expression)? END #simpleCase
| CAST '(' expression AS dataType ')' #cast
+ | STRUCT '(' (argument+=namedExpression (',' argument+=namedExpression)*)? ')' #struct
| FIRST '(' expression (IGNORE NULLS)? ')' #first
| LAST '(' expression (IGNORE NULLS)? ')' #last
| constant #constantDefault
@@ -559,7 +559,7 @@ primaryExpression
| qualifiedName '.' ASTERISK #star
| '(' namedExpression (',' namedExpression)+ ')' #rowConstructor
| '(' query ')' #subqueryExpression
- | qualifiedName '(' (setQuantifier? namedExpression (',' namedExpression)*)? ')'
+ | qualifiedName '(' (setQuantifier? argument+=expression (',' argument+=expression)*)? ')'
(OVER windowSpec)? #functionCall
| value=primaryExpression '[' index=valueExpression ']' #subscript
| identifier #columnReference
@@ -726,7 +726,7 @@ nonReserved
| NULL | ORDER | OUTER | TABLE | TRUE | WITH | RLIKE
| AND | CASE | CAST | DISTINCT | DIV | ELSE | END | FUNCTION | INTERVAL | MACRO | OR | STRATIFY | THEN
| UNBOUNDED | WHEN
- | DATABASE | SELECT | FROM | WHERE | HAVING | TO | TABLE | WITH | NOT | CURRENT_DATE | CURRENT_TIMESTAMP
+ | DATABASE | SELECT | FROM | WHERE | HAVING | TO | TABLE | WITH | NOT
;
SELECT: 'SELECT';
@@ -954,8 +954,6 @@ OPTION: 'OPTION';
ANTI: 'ANTI';
LOCAL: 'LOCAL';
INPATH: 'INPATH';
-CURRENT_DATE: 'CURRENT_DATE';
-CURRENT_TIMESTAMP: 'CURRENT_TIMESTAMP';
STRING
: '\'' ( ~('\''|'\\') | ('\\' .) )* '\''
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 86de90984ca00..56994fafe064b 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -550,7 +550,7 @@ public void copyFrom(UnsafeRow row) {
*/
public void writeToStream(OutputStream out, byte[] writeBuffer) throws IOException {
if (baseObject instanceof byte[]) {
- int offsetInByteArray = (int) (Platform.BYTE_ARRAY_OFFSET - baseOffset);
+ int offsetInByteArray = (int) (baseOffset - Platform.BYTE_ARRAY_OFFSET);
out.write((byte[]) baseObject, offsetInByteArray, sizeInBytes);
} else {
int dataRemaining = sizeInBytes;
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java
index bd5e2d7ecca9b..5f1032d1229da 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java
@@ -37,7 +37,9 @@ public class GroupStateTimeout {
* `map/flatMapGroupsWithState` by calling `GroupState.setTimeoutDuration()`. See documentation
* on `GroupState` for more details.
*/
- public static GroupStateTimeout ProcessingTimeTimeout() { return ProcessingTimeTimeout$.MODULE$; }
+ public static GroupStateTimeout ProcessingTimeTimeout() {
+ return ProcessingTimeTimeout$.MODULE$;
+ }
/**
* Timeout based on event-time. The event-time timestamp for timeout can be set for each
@@ -51,4 +53,5 @@ public class GroupStateTimeout {
/** No timeout. */
public static GroupStateTimeout NoTimeout() { return NoTimeout$.MODULE$; }
+
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java
index 3f7cdb293e0fa..2800b3068f87b 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java
@@ -17,19 +17,15 @@
package org.apache.spark.sql.streaming;
-import org.apache.spark.annotation.Experimental;
import org.apache.spark.annotation.InterfaceStability;
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes;
/**
- * :: Experimental ::
- *
* OutputMode is used to describe what data will be written to a streaming sink when there is
* new data available in a streaming DataFrame/Dataset.
*
* @since 2.0.0
*/
-@Experimental
@InterfaceStability.Evolving
public class OutputMode {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index d4ebdb139fe0f..474ec592201d9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -310,7 +310,7 @@ object CatalystTypeConverters {
case d: JavaBigInteger => Decimal(d)
case d: Decimal => d
}
- decimal.toPrecision(dataType.precision, dataType.scale).orNull
+ decimal.toPrecision(dataType.precision, dataType.scale)
}
override def toScala(catalystValue: Decimal): JavaBigDecimal = {
if (catalystValue == null) null
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 86a73a319ec3f..2698faef76902 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -423,6 +423,7 @@ object JavaTypeInference {
inputObject,
ObjectType(keyType.getRawType),
serializerFor(_, keyType),
+ keyNullable = true,
ObjectType(valueType.getRawType),
serializerFor(_, valueType),
valueNullable = true
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 82710a2a183ab..7f727515036d7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -62,7 +62,7 @@ object ScalaReflection extends ScalaReflection {
*/
def dataTypeFor[T : TypeTag]: DataType = dataTypeFor(localTypeOf[T])
- private def dataTypeFor(tpe: `Type`): DataType = ScalaReflectionLock.synchronized {
+ private def dataTypeFor(tpe: `Type`): DataType = cleanUpReflectionObjects {
tpe match {
case t if t <:< definitions.IntTpe => IntegerType
case t if t <:< definitions.LongTpe => LongType
@@ -92,7 +92,7 @@ object ScalaReflection extends ScalaReflection {
* Array[T]. Special handling is performed for primitive types to map them back to their raw
* JVM form instead of the Scala Array that handles auto boxing.
*/
- private def arrayClassFor(tpe: `Type`): ObjectType = ScalaReflectionLock.synchronized {
+ private def arrayClassFor(tpe: `Type`): ObjectType = cleanUpReflectionObjects {
val cls = tpe match {
case t if t <:< definitions.IntTpe => classOf[Array[Int]]
case t if t <:< definitions.LongTpe => classOf[Array[Long]]
@@ -133,19 +133,25 @@ object ScalaReflection extends ScalaReflection {
val tpe = localTypeOf[T]
val clsName = getClassNameFromType(tpe)
val walkedTypePath = s"""- root class: "$clsName"""" :: Nil
- deserializerFor(tpe, None, walkedTypePath)
+ val expr = deserializerFor(tpe, None, walkedTypePath)
+ val Schema(_, nullable) = schemaFor(tpe)
+ if (nullable) {
+ expr
+ } else {
+ AssertNotNull(expr, walkedTypePath)
+ }
}
private def deserializerFor(
tpe: `Type`,
path: Option[Expression],
- walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized {
+ walkedTypePath: Seq[String]): Expression = cleanUpReflectionObjects {
/** Returns the current path with a sub-field extracted. */
def addToPath(part: String, dataType: DataType, walkedTypePath: Seq[String]): Expression = {
val newPath = path
.map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
- .getOrElse(UnresolvedAttribute(part))
+ .getOrElse(UnresolvedAttribute.quoted(part))
upCastToExpectedType(newPath, dataType, walkedTypePath)
}
@@ -446,7 +452,7 @@ object ScalaReflection extends ScalaReflection {
inputObject: Expression,
tpe: `Type`,
walkedTypePath: Seq[String],
- seenTypeSet: Set[`Type`] = Set.empty): Expression = ScalaReflectionLock.synchronized {
+ seenTypeSet: Set[`Type`] = Set.empty): Expression = cleanUpReflectionObjects {
def toCatalystArray(input: Expression, elementType: `Type`): Expression = {
dataTypeFor(elementType) match {
@@ -511,6 +517,7 @@ object ScalaReflection extends ScalaReflection {
inputObject,
dataTypeFor(keyType),
serializerFor(_, keyType, keyPath, seenTypeSet),
+ keyNullable = !keyType.typeSymbol.asClass.isPrimitive,
dataTypeFor(valueType),
serializerFor(_, valueType, valuePath, seenTypeSet),
valueNullable = !valueType.typeSymbol.asClass.isPrimitive)
@@ -631,7 +638,7 @@ object ScalaReflection extends ScalaReflection {
* Returns true if the given type is option of product type, e.g. `Option[Tuple2]`. Note that,
* we also treat [[DefinedByConstructorParams]] as product type.
*/
- def optionOfProductType(tpe: `Type`): Boolean = ScalaReflectionLock.synchronized {
+ def optionOfProductType(tpe: `Type`): Boolean = cleanUpReflectionObjects {
tpe match {
case t if t <:< localTypeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t
@@ -664,7 +671,7 @@ object ScalaReflection extends ScalaReflection {
val m = runtimeMirror(cls.getClassLoader)
val classSymbol = m.staticClass(cls.getName)
val t = classSymbol.selfType
- constructParams(t).map(_.name.toString)
+ constructParams(t).map(_.name.decodedName.toString)
}
/**
@@ -693,7 +700,7 @@ object ScalaReflection extends ScalaReflection {
def schemaFor[T: TypeTag]: Schema = schemaFor(localTypeOf[T])
/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
- def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized {
+ def schemaFor(tpe: `Type`): Schema = cleanUpReflectionObjects {
tpe match {
case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) =>
val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
@@ -759,7 +766,7 @@ object ScalaReflection extends ScalaReflection {
/**
* Whether the fields of the given type is defined entirely by its constructor parameters.
*/
- def definedByConstructorParams(tpe: Type): Boolean = {
+ def definedByConstructorParams(tpe: Type): Boolean = cleanUpReflectionObjects {
tpe <:< localTypeOf[Product] || tpe <:< localTypeOf[DefinedByConstructorParams]
}
@@ -788,6 +795,20 @@ trait ScalaReflection {
// Since the map values can be mutable, we explicitly import scala.collection.Map at here.
import scala.collection.Map
+ /**
+ * Any code that calls `scala.reflect.api.Types.TypeApi.<:<` should be wrapped by this method to
+ * clean up the Scala reflection garbage automatically. Otherwise, it will leak some objects to
+ * `scala.reflect.runtime.JavaUniverse.undoLog`.
+ *
+ * This method will also wrap `func` with `ScalaReflectionLock.synchronized` so the caller doesn't
+ * need to call it again.
+ *
+ * @see https://github.com/scala/bug/issues/8302
+ */
+ def cleanUpReflectionObjects[T](func: => T): T = ScalaReflectionLock.synchronized {
+ universe.asInstanceOf[scala.reflect.runtime.JavaUniverse].undoLog.undo(func)
+ }
+
/**
* Return the Scala Type for `T` in the current classloader mirror.
*
@@ -836,8 +857,17 @@ trait ScalaReflection {
def getConstructorParameters(tpe: Type): Seq[(String, Type)] = {
val formalTypeArgs = tpe.typeSymbol.asClass.typeParams
val TypeRef(_, _, actualTypeArgs) = tpe
- constructParams(tpe).map { p =>
- p.name.toString -> p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
+ val params = constructParams(tpe)
+ // if there are type variables to fill in, do the substitution (SomeClass[T] -> SomeClass[Int])
+ if (actualTypeArgs.nonEmpty) {
+ params.map { p =>
+ p.name.decodedName.toString ->
+ p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
+ }
+ } else {
+ params.map { p =>
+ p.name.decodedName.toString -> p.typeSignature
+ }
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 9816b33ae8dff..5b26a23fd6cf9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -136,6 +136,7 @@ class Analyzer(
ResolveGroupingAnalytics ::
ResolvePivot ::
ResolveOrdinalInOrderByAndGroupBy ::
+ ResolveAggAliasInGroupBy ::
ResolveMissingReferences ::
ExtractGenerator ::
ResolveGenerate ::
@@ -150,6 +151,7 @@ class Analyzer(
ResolveAggregateFunctions ::
TimeWindowing ::
ResolveInlineTables(conf) ::
+ ResolveTimeZone(conf) ::
TypeCoercion.typeCoercionRules ++
extendedResolutionRules : _*),
Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*),
@@ -161,8 +163,6 @@ class Analyzer(
HandleNullInputsForUDF),
Batch("FixNullability", Once,
FixNullability),
- Batch("ResolveTimeZone", Once,
- ResolveTimeZone),
Batch("Subquery", Once,
UpdateOuterReferences),
Batch("Cleanup", fixedPoint,
@@ -173,7 +173,7 @@ class Analyzer(
* Analyze cte definitions and substitute child plan with analyzed cte definitions.
*/
object CTESubstitution extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case With(child, relations) =>
substituteCTE(child, relations.foldLeft(Seq.empty[(String, LogicalPlan)]) {
case (resolved, (name, relation)) =>
@@ -201,7 +201,7 @@ class Analyzer(
* Substitute child plan with WindowSpecDefinitions.
*/
object WindowsSubstitution extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
// Lookup WindowSpecDefinitions. This rule works with unresolved children.
case WithWindowDefinition(windowDefinitions, child) =>
child.transform {
@@ -243,7 +243,7 @@ class Analyzer(
private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) =
exprs.exists(_.find(_.isInstanceOf[UnresolvedAlias]).isDefined)
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) =>
Aggregate(groups, assignAliases(aggs), child)
@@ -280,9 +280,15 @@ class Analyzer(
* We need to get all of its subsets for a given GROUPBY expression, the subsets are
* represented as sequence of expressions.
*/
- def cubeExprs(exprs: Seq[Expression]): Seq[Seq[Expression]] = exprs.toList match {
+ def cubeExprs(exprs: Seq[Expression]): Seq[Seq[Expression]] = {
+ // `cubeExprs0` is recursive and returns a lazy Stream. Here we call `toIndexedSeq` to
+ // materialize it and avoid serialization problems later on.
+ cubeExprs0(exprs).toIndexedSeq
+ }
+
+ def cubeExprs0(exprs: Seq[Expression]): Seq[Seq[Expression]] = exprs.toList match {
case x :: xs =>
- val initial = cubeExprs(xs)
+ val initial = cubeExprs0(xs)
initial.map(x +: _) ++ initial
case Nil =>
Seq(Seq.empty)
@@ -315,7 +321,7 @@ class Analyzer(
s"grouping columns (${groupByExprs.mkString(",")})")
}
case e @ Grouping(col: Expression) =>
- val idx = groupByExprs.indexOf(col)
+ val idx = groupByExprs.indexWhere(_.semanticEquals(col))
if (idx >= 0) {
Alias(Cast(BitwiseAnd(ShiftRight(gid, Literal(groupByExprs.length - 1 - idx)),
Literal(1)), ByteType), toPrettySQL(e))()
@@ -615,7 +621,7 @@ class Analyzer(
case _ => plan
}
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved =>
EliminateSubqueryAliases(lookupTableFromCatalog(u)) match {
case v: View =>
@@ -787,7 +793,7 @@ class Analyzer(
}
}
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case p: LogicalPlan if !p.childrenResolved => p
// If the projection list contains Stars, expand it.
@@ -845,11 +851,15 @@ class Analyzer(
case q: LogicalPlan =>
logTrace(s"Attempting to resolve ${q.simpleString}")
- q transformExpressionsUp {
+ q.transformExpressionsUp {
case u @ UnresolvedAttribute(nameParts) =>
- // Leave unchanged if resolution fails. Hopefully will be resolved next round.
+ // Leave unchanged if resolution fails. Hopefully will be resolved next round.
val result =
- withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) }
+ withPosition(u) {
+ q.resolveChildren(nameParts, resolver)
+ .orElse(resolveLiteralFunction(nameParts, u, q))
+ .getOrElse(u)
+ }
logDebug(s"Resolving $u to $result")
result
case UnresolvedExtractValue(child, fieldExpr) if child.resolved =>
@@ -927,6 +937,30 @@ class Analyzer(
exprs.exists(_.find(_.isInstanceOf[UnresolvedDeserializer]).isDefined)
}
+ /**
+ * Literal functions do not require the user to specify parentheses when calling them.
+ * When an attribute is not resolvable, we try to resolve it as a literal function.
+ */
+ private def resolveLiteralFunction(
+ nameParts: Seq[String],
+ attribute: UnresolvedAttribute,
+ plan: LogicalPlan): Option[Expression] = {
+ if (nameParts.length != 1) return None
+ val isNamedExpression = plan match {
+ case Aggregate(_, aggregateExpressions, _) => aggregateExpressions.contains(attribute)
+ case Project(projectList, _) => projectList.contains(attribute)
+ case Window(windowExpressions, _, _, _) => windowExpressions.contains(attribute)
+ case _ => false
+ }
+ val wrapper: Expression => Expression =
+ if (isNamedExpression) f => Alias(f, toPrettySQL(f))() else identity
+ // support CURRENT_DATE and CURRENT_TIMESTAMP
+ val literalFunctions = Seq(CurrentDate(), CurrentTimestamp())
+ val name = nameParts.head
+ val func = literalFunctions.find(e => resolver(e.prettyName, name))
+ func.map(wrapper)
+ }
+
protected[sql] def resolveExpression(
expr: Expression,
plan: LogicalPlan,
@@ -939,7 +973,11 @@ class Analyzer(
expr transformUp {
case GetColumnByOrdinal(ordinal, _) => plan.output(ordinal)
case u @ UnresolvedAttribute(nameParts) =>
- withPosition(u) { plan.resolve(nameParts, resolver).getOrElse(u) }
+ withPosition(u) {
+ plan.resolve(nameParts, resolver)
+ .orElse(resolveLiteralFunction(nameParts, u, plan))
+ .getOrElse(u)
+ }
case UnresolvedExtractValue(child, fieldName) if child.resolved =>
ExtractValue(child, fieldName, resolver)
}
@@ -962,11 +1000,11 @@ class Analyzer(
* have no effect on the results.
*/
object ResolveOrdinalInOrderByAndGroupBy extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case p if !p.childrenResolved => p
// Replace the index with the related attribute for ORDER BY,
// which is a 1-base position of the projection list.
- case s @ Sort(orders, global, child)
+ case Sort(orders, global, child)
if orders.exists(_.child.isInstanceOf[UnresolvedOrdinal]) =>
val newOrders = orders map {
case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering, _) =>
@@ -983,17 +1021,11 @@ class Analyzer(
// Replace the index with the corresponding expression in aggregateExpressions. The index is
// a 1-base position of aggregateExpressions, which is output columns (select expression)
- case a @ Aggregate(groups, aggs, child) if aggs.forall(_.resolved) &&
+ case Aggregate(groups, aggs, child) if aggs.forall(_.resolved) &&
groups.exists(_.isInstanceOf[UnresolvedOrdinal]) =>
val newGroups = groups.map {
- case ordinal @ UnresolvedOrdinal(index) if index > 0 && index <= aggs.size =>
- aggs(index - 1) match {
- case e if ResolveAggregateFunctions.containsAggregate(e) =>
- ordinal.failAnalysis(
- s"GROUP BY position $index is an aggregate function, and " +
- "aggregate functions are not allowed in GROUP BY")
- case o => o
- }
+ case u @ UnresolvedOrdinal(index) if index > 0 && index <= aggs.size =>
+ aggs(index - 1)
case ordinal @ UnresolvedOrdinal(index) =>
ordinal.failAnalysis(
s"GROUP BY position $index is not in select list " +
@@ -1004,6 +1036,41 @@ class Analyzer(
}
}
+ /**
+ * Replace unresolved expressions in grouping keys with resolved ones in SELECT clauses.
+ * This rule is expected to run after [[ResolveReferences]] applied.
+ */
+ object ResolveAggAliasInGroupBy extends Rule[LogicalPlan] {
+
+ // This is a strict check, though: we apply this rule only if the expression is not
+ // resolvable by the child.
+ private def notResolvableByChild(attrName: String, child: LogicalPlan): Boolean = {
+ !child.output.exists(a => resolver(a.name, attrName))
+ }
+
+ private def mayResolveAttrByAggregateExprs(
+ exprs: Seq[Expression], aggs: Seq[NamedExpression], child: LogicalPlan): Seq[Expression] = {
+ exprs.map { _.transform {
+ case u: UnresolvedAttribute if notResolvableByChild(u.name, child) =>
+ aggs.find(ne => resolver(ne.name, u.name)).getOrElse(u)
+ }}
+ }
+
+ override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+ case agg @ Aggregate(groups, aggs, child)
+ if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) &&
+ groups.exists(!_.resolved) =>
+ agg.copy(groupingExpressions = mayResolveAttrByAggregateExprs(groups, aggs, child))
+
+ case gs @ GroupingSets(selectedGroups, groups, child, aggs)
+ if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) &&
+ groups.exists(_.isInstanceOf[UnresolvedAttribute]) =>
+ gs.copy(
+ selectedGroupByExprs = selectedGroups.map(mayResolveAttrByAggregateExprs(_, aggs, child)),
+ groupByExprs = mayResolveAttrByAggregateExprs(groups, aggs, child))
+ }
+ }
+
/**
* In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT
* clause. This rule detects such queries and adds the required attributes to the original
@@ -1013,7 +1080,7 @@ class Analyzer(
* The HAVING clause could also used a grouping columns that is not presented in the SELECT.
*/
object ResolveMissingReferences extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
// Skip sort with aggregate. This will be handled in ResolveAggregateFunctions
case sa @ Sort(_, _, child: Aggregate) => sa
@@ -1137,7 +1204,7 @@ class Analyzer(
* Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s.
*/
object ResolveFunctions extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case q: LogicalPlan =>
q transformExpressions {
case u if !u.childrenResolved => u // Skip until children are resolved.
@@ -1161,11 +1228,21 @@ class Analyzer(
// AggregateWindowFunctions are AggregateFunctions that can only be evaluated within
// the context of a Window clause. They do not need to be wrapped in an
// AggregateExpression.
- case wf: AggregateWindowFunction => wf
+ case wf: AggregateWindowFunction =>
+ if (isDistinct) {
+ failAnalysis(s"${wf.prettyName} does not support the modifier DISTINCT")
+ } else {
+ wf
+ }
// We get an aggregate function, we need to wrap it in an AggregateExpression.
case agg: AggregateFunction => AggregateExpression(agg, Complete, isDistinct)
// This function is not an aggregate function, just return the resolved one.
- case other => other
+ case other =>
+ if (isDistinct) {
+ failAnalysis(s"${other.prettyName} does not support the modifier DISTINCT")
+ } else {
+ other
+ }
}
}
}
@@ -1283,7 +1360,7 @@ class Analyzer(
// Category 1:
// BroadcastHint, Distinct, LeafNode, Repartition, and SubqueryAlias
- case _: BroadcastHint | _: Distinct | _: LeafNode | _: Repartition | _: SubqueryAlias =>
+ case _: ResolvedHint | _: Distinct | _: LeafNode | _: Repartition | _: SubqueryAlias =>
// Category 2:
// These operators can be anywhere in a correlated subquery.
@@ -1449,7 +1526,7 @@ class Analyzer(
/**
* Resolve and rewrite all subqueries in an operator tree..
*/
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
// In case of HAVING (a filter after an aggregate) we use both the aggregate and
// its child for resolution.
case f @ Filter(_, a: Aggregate) if f.childrenResolved =>
@@ -1464,7 +1541,7 @@ class Analyzer(
* Turns projections that contain aggregate expressions into aggregations.
*/
object GlobalAggregates extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case Project(projectList, child) if containsAggregates(projectList) =>
Aggregate(Nil, projectList, child)
}
@@ -1490,7 +1567,7 @@ class Analyzer(
* underlying aggregate operator and then projected away after the original operator.
*/
object ResolveAggregateFunctions extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case filter @ Filter(havingCondition,
aggregate @ Aggregate(grouping, originalAggExprs, child))
if aggregate.resolved =>
@@ -1574,7 +1651,7 @@ class Analyzer(
// to push down this ordering expression and can reference the original aggregate
// expression instead.
val needsPushDown = ArrayBuffer.empty[NamedExpression]
- val evaluatedOrderings = resolvedAliasedOrdering.zip(sortOrder).map {
+ val evaluatedOrderings = resolvedAliasedOrdering.zip(unresolvedSortOrders).map {
case (evaluated, order) =>
val index = originalAggExprs.indexWhere {
case Alias(child, _) => child semanticEquals evaluated.child
@@ -1662,7 +1739,7 @@ class Analyzer(
}
}
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case Project(projectList, _) if projectList.exists(hasNestedGenerator) =>
val nestedGenerator = projectList.find(hasNestedGenerator).get
throw new AnalysisException("Generators are not supported when it's nested in " +
@@ -1720,7 +1797,7 @@ class Analyzer(
* that wrap the [[Generator]].
*/
object ResolveGenerate extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case g: Generate if !g.child.resolved || !g.generator.resolved => g
case g: Generate if !g.resolved =>
g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name)))
@@ -2037,7 +2114,7 @@ class Analyzer(
* put them into an inner Project and finally project them away at the outer Project.
*/
object PullOutNondeterministic extends Rule[LogicalPlan] {
- override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case p if !p.resolved => p // Skip unresolved nodes.
case p: Project => p
case f: Filter => f
@@ -2082,7 +2159,7 @@ class Analyzer(
* and we should return null if the input is null.
*/
object HandleNullInputsForUDF extends Rule[LogicalPlan] {
- override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case p if !p.resolved => p // Skip unresolved nodes.
case p => p transformExpressionsUp {
@@ -2147,7 +2224,7 @@ class Analyzer(
* Then apply a Project on a normal Join to eliminate natural or using join.
*/
object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] {
- override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case j @ Join(left, right, UsingJoin(joinType, usingCols), condition)
if left.resolved && right.resolved && j.duplicateResolved =>
commonNaturalJoinProcessing(left, right, joinType, usingCols, None)
@@ -2212,7 +2289,7 @@ class Analyzer(
* to the given input attributes.
*/
object ResolveDeserializer extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case p if !p.childrenResolved => p
case p if p.resolved => p
@@ -2230,8 +2307,8 @@ class Analyzer(
val result = resolved transformDown {
case UnresolvedMapObjects(func, inputData, cls) if inputData.resolved =>
inputData.dataType match {
- case ArrayType(et, _) =>
- val expr = MapObjects(func, inputData, et, cls) transformUp {
+ case ArrayType(et, cn) =>
+ val expr = MapObjects(func, inputData, et, cn, cls) transformUp {
case UnresolvedExtractValue(child, fieldName) if child.resolved =>
ExtractValue(child, fieldName, resolver)
}
@@ -2298,7 +2375,7 @@ class Analyzer(
* constructed is an inner class.
*/
object ResolveNewInstance extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case p if !p.childrenResolved => p
case p if p.resolved => p
@@ -2332,7 +2409,7 @@ class Analyzer(
"type of the field in the target object")
}
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case p if !p.childrenResolved => p
case p if p.resolved => p
@@ -2347,23 +2424,6 @@ class Analyzer(
}
}
}
-
- /**
- * Replace [[TimeZoneAwareExpression]] without timezone id by its copy with session local
- * time zone.
- */
- object ResolveTimeZone extends Rule[LogicalPlan] {
-
- override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions {
- case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty =>
- e.withTimeZone(conf.sessionLocalTimeZone)
- // Casts could be added in the subquery plan through the rule TypeCoercion while coercing
- // the types between the value expression and list query expression of IN expression.
- // We need to subject the subquery plan through ResolveTimeZone again to setup timezone
- // information for time zone aware expressions.
- case e: ListQuery => e.withNewPlan(apply(e.plan))
- }
- }
}
/**
@@ -2388,7 +2448,9 @@ object EliminateUnions extends Rule[LogicalPlan] {
/**
* Cleans up unnecessary Aliases inside the plan. Basically we only need Alias as a top level
* expression in Project(project list) or Aggregate(aggregate expressions) or
- * Window(window expressions).
+ * Window(window expressions). Notice that if an expression has other expression parameters which
+ * are not in its `children`, e.g. `RuntimeReplaceable`, the transformation for Aliases in this
+ * rule can't work for those parameters.
*/
object CleanupAliases extends Rule[LogicalPlan] {
private def trimAliases(e: Expression): Expression = {
@@ -2403,7 +2465,7 @@ object CleanupAliases extends Rule[LogicalPlan] {
case other => trimAliases(other)
}
- override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case Project(projectList, child) =>
val cleanedProjectList =
projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
@@ -2432,6 +2494,16 @@ object CleanupAliases extends Rule[LogicalPlan] {
}
}
+/**
+ * Ignore event time watermark in batch query, which is only supported in Structured Streaming.
+ * TODO: add this rule into analyzer rule list.
+ */
+object EliminateEventTimeWatermark extends Rule[LogicalPlan] {
+ override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case EventTimeWatermark(_, _, child) if !child.isStreaming => child
+ }
+}
+
/**
* Maps a time column to multiple time windows using the Expand operator. Since it's non-trivial to
* figure out how many windows a time column can map to, we over-estimate the number of windows and
@@ -2471,7 +2543,7 @@ object TimeWindowing extends Rule[LogicalPlan] {
* @return the logical plan that will generate the time windows using the Expand operator, with
* the Filter operator for correctness and Project for usability.
*/
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case p: LogicalPlan if p.children.size == 1 =>
val child = p.children.head
val windowExpressions =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index da0c6b098f5ce..2e3ac3e474866 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -130,12 +130,13 @@ trait CheckAnalysis extends PredicateHelper {
}
case s @ ScalarSubquery(query, conditions, _) =>
+ checkAnalysis(query)
+
// If no correlation, the output must be exactly one column
if (conditions.isEmpty && query.output.size != 1) {
failAnalysis(
s"Scalar subquery must return only one column, but got ${query.output.size}")
- }
- else if (conditions.nonEmpty) {
+ } else if (conditions.nonEmpty) {
def checkAggregate(agg: Aggregate): Unit = {
// Make sure correlated scalar subqueries contain one row for every outer row by
// enforcing that they are aggregates containing exactly one aggregate expression.
@@ -179,7 +180,6 @@ trait CheckAnalysis extends PredicateHelper {
case fail => failAnalysis(s"Correlated scalar subqueries must be Aggregated: $fail")
}
}
- checkAnalysis(query)
s
case s: SubqueryExpression =>
@@ -254,6 +254,11 @@ trait CheckAnalysis extends PredicateHelper {
}
def checkValidGroupingExprs(expr: Expression): Unit = {
+ if (expr.find(_.isInstanceOf[AggregateExpression]).isDefined) {
+ failAnalysis(
+ "aggregate functions are not allowed in GROUP BY, but found " + expr.sql)
+ }
+
// Check if the data type of expr is orderable.
if (!RowOrdering.isOrderable(expr.dataType)) {
failAnalysis(
@@ -271,8 +276,8 @@ trait CheckAnalysis extends PredicateHelper {
}
}
- aggregateExprs.foreach(checkValidAggregateExpression)
groupingExprs.foreach(checkValidGroupingExprs)
+ aggregateExprs.foreach(checkValidAggregateExpression)
case Sort(orders, _, _) =>
orders.foreach { order =>
@@ -394,7 +399,7 @@ trait CheckAnalysis extends PredicateHelper {
|in operator ${operator.simpleString}
""".stripMargin)
- case _: Hint =>
+ case _: UnresolvedHint =>
throw new IllegalStateException(
"Internal error: logical hint operator should have been removed during analysis")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
index 9c38dd2ee4e53..a48801c5ee140 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
@@ -80,12 +80,12 @@ object DecimalPrecision extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
// fix decimal precision for expressions
- case q => q.transformExpressions(
+ case q => q.transformExpressionsUp(
decimalAndDecimal.orElse(integralAndDecimalLiteral).orElse(nondecimalAndDecimal))
}
/** Decimal precision promotion for +, -, *, /, %, pmod, and binary comparison. */
- private val decimalAndDecimal: PartialFunction[Expression, Expression] = {
+ private[catalyst] val decimalAndDecimal: PartialFunction[Expression, Expression] = {
// Skip nodes whose children have not been resolved yet
case e if !e.childrenResolved => e
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index e1d83a86f99dc..96b6b11582162 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.analysis
+import java.lang.reflect.Modifier
+
import scala.language.existentials
import scala.reflect.ClassTag
import scala.util.{Failure, Success, Try}
@@ -428,6 +430,8 @@ object FunctionRegistry {
expression[StructsToJson]("to_json"),
expression[JsonToStructs]("from_json"),
+ // cast
+ expression[Cast]("cast"),
// Cast aliases (SPARK-16730)
castAlias("boolean", BooleanType),
castAlias("tinyint", ByteType),
@@ -455,8 +459,17 @@ object FunctionRegistry {
private def expression[T <: Expression](name: String)
(implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = {
+ // For `RuntimeReplaceable`, skip the constructor with the most arguments, which is the main
+ // constructor; it takes the non-parameter `child` and should not be used as the function builder.
+ val constructors = if (classOf[RuntimeReplaceable].isAssignableFrom(tag.runtimeClass)) {
+ val all = tag.runtimeClass.getConstructors
+ val maxNumArgs = all.map(_.getParameterCount).max
+ all.filterNot(_.getParameterCount == maxNumArgs)
+ } else {
+ tag.runtimeClass.getConstructors
+ }
// See if we can find a constructor that accepts Seq[Expression]
- val varargCtor = Try(tag.runtimeClass.getDeclaredConstructor(classOf[Seq[_]])).toOption
+ val varargCtor = constructors.find(_.getParameterTypes.toSeq == Seq(classOf[Seq[_]]))
val builder = (expressions: Seq[Expression]) => {
if (varargCtor.isDefined) {
// If there is an apply method that accepts Seq[Expression], use that one.
@@ -470,11 +483,8 @@ object FunctionRegistry {
} else {
// Otherwise, find a constructor method that matches the number of arguments, and use that.
val params = Seq.fill(expressions.size)(classOf[Expression])
- val f = Try(tag.runtimeClass.getDeclaredConstructor(params : _*)) match {
- case Success(e) =>
- e
- case Failure(e) =>
- throw new AnalysisException(s"Invalid number of arguments for function $name")
+ val f = constructors.find(_.getParameterTypes.toSeq == params).getOrElse {
+ throw new AnalysisException(s"Invalid number of arguments for function $name")
}
Try(f.newInstance(expressions : _*).asInstanceOf[Expression]) match {
case Success(e) => e
@@ -504,7 +514,9 @@ object FunctionRegistry {
}
Cast(args.head, dataType)
}
- (name, (expressionInfo[Cast](name), builder))
+ val clazz = scala.reflect.classTag[Cast].runtimeClass
+ val usage = "_FUNC_(expr) - Casts the value `expr` to the target data type `_FUNC_`."
+ (name, (new ExpressionInfo(clazz.getCanonicalName, null, name, usage, null), builder))
}
/**
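
The constructor lookup above can be reproduced in a small standalone sketch; `findExprConstructor` is a made-up helper, and only reflection plus the catalyst `Add`/`Literal` expressions are assumed:

    import org.apache.spark.sql.catalyst.expressions.{Add, Expression, Literal}

    object ConstructorMatchSketch {
      // Find a public constructor taking exactly `n` Expression parameters, mirroring the lookup above.
      def findExprConstructor(clazz: Class[_], n: Int): Option[java.lang.reflect.Constructor[_]] = {
        val params = Seq.fill(n)(classOf[Expression])
        clazz.getConstructors.find(_.getParameterTypes.toSeq == params)
      }

      def main(args: Array[String]): Unit = {
        // Add(left, right) has a two-Expression constructor, so arity 2 matches and arity 3 does not.
        val ctor = findExprConstructor(classOf[Add], 2).getOrElse(sys.error("no matching constructor"))
        val expr = ctor.newInstance(Literal(1), Literal(2)).asInstanceOf[Expression]
        println(expr)                                  // e.g. (1 + 2)
        println(findExprConstructor(classOf[Add], 3))  // None -> "Invalid number of arguments" above
      }
    }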
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
index c4827b81e8b63..62a3482d9fac1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis
import java.util.Locale
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
@@ -57,11 +58,11 @@ object ResolveHints {
val newNode = CurrentOrigin.withOrigin(plan.origin) {
plan match {
case u: UnresolvedRelation if toBroadcast.exists(resolver(_, u.tableIdentifier.table)) =>
- BroadcastHint(plan)
+ ResolvedHint(plan, HintInfo(isBroadcastable = Option(true)))
case r: SubqueryAlias if toBroadcast.exists(resolver(_, r.alias)) =>
- BroadcastHint(plan)
+ ResolvedHint(plan, HintInfo(isBroadcastable = Option(true)))
- case _: BroadcastHint | _: View | _: With | _: SubqueryAlias =>
+ case _: ResolvedHint | _: View | _: With | _: SubqueryAlias =>
// Don't traverse down these nodes.
// For an existing broadcast hint, there is no point going down (if we do, we either
// won't change the structure, or will introduce another broadcast hint that is useless).
@@ -85,8 +86,19 @@ object ResolveHints {
}
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
- case h: Hint if BROADCAST_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) =>
- applyBroadcastHint(h.child, h.parameters.toSet)
+ case h: UnresolvedHint if BROADCAST_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) =>
+ if (h.parameters.isEmpty) {
+ // If there is no table alias specified, turn the entire subtree into a BroadcastHint.
+ ResolvedHint(h.child, HintInfo(isBroadcastable = Option(true)))
+ } else {
+ // Otherwise, find within the subtree query plans that should be broadcasted.
+ applyBroadcastHint(h.child, h.parameters.map {
+ case tableName: String => tableName
+ case tableId: UnresolvedAttribute => tableId.name
+ case unsupported => throw new AnalysisException("Broadcast hint parameter should be " +
+ s"an identifier or string but was $unsupported (${unsupported.getClass}")
+ }.toSet)
+ }
}
}
@@ -96,7 +108,7 @@ object ResolveHints {
*/
object RemoveAllHints extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
- case h: Hint => h.child
+ case h: UnresolvedHint => h.child
}
}
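
Illustrative usage of the hint resolution above (table names are made up; the SQL hint syntax and the parameterless Dataset.hint are assumed to be available in this Spark line):

    import org.apache.spark.sql.SparkSession

    object BroadcastHintExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[1]").appName("hints").getOrCreate()
        import spark.implicits._

        Seq((1, "a"), (2, "b")).toDF("id", "v").createOrReplaceTempView("small")
        spark.range(0, 1000).toDF("id").createOrReplaceTempView("large")

        // Parameterized form: the hint names the relation to broadcast within the subtree.
        val q1 = spark.sql(
          "SELECT /*+ BROADCAST(small) */ * FROM large JOIN small ON large.id = small.id")
        q1.explain()  // the physical plan should pick a broadcast join with `small` as the build side

        // Parameterless form: marks the entire child plan as broadcastable.
        val q2 = spark.table("large").join(spark.table("small").hint("broadcast"), "id")
        q2.explain()

        spark.stop()
      }
    }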
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala
index a991dd96e2828..f2df3e132629f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.analysis
import scala.util.control.NonFatal
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Cast, TimeZoneAwareExpression}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
@@ -29,7 +28,7 @@ import org.apache.spark.sql.types.{StructField, StructType}
/**
* An analyzer rule that replaces [[UnresolvedInlineTable]] with [[LocalRelation]].
*/
-case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] {
+case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport {
override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case table: UnresolvedInlineTable if table.expressionsResolved =>
validateInputDimension(table)
@@ -99,12 +98,9 @@ case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] {
val castedExpr = if (e.dataType.sameType(targetType)) {
e
} else {
- Cast(e, targetType)
+ cast(e, targetType)
}
- castedExpr.transform {
- case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty =>
- e.withTimeZone(conf.sessionLocalTimeZone)
- }.eval()
+ castedExpr.eval()
} catch {
case NonFatal(ex) =>
table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala
index 8841309939c24..de6de24350f23 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.analysis
+import java.util.Locale
+
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range}
import org.apache.spark.sql.catalyst.rules._
@@ -103,7 +105,7 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) =>
- builtinFunctions.get(u.functionName.toLowerCase()) match {
+ builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match {
case Some(tvf) =>
val resolved = tvf.flatMap { case (argList, resolver) =>
argList.implicitCast(u.functionArgs) match {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala
index 256b18771052a..860d20f897690 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala
@@ -33,7 +33,7 @@ class SubstituteUnresolvedOrdinals(conf: SQLConf) extends Rule[LogicalPlan] {
case _ => false
}
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case s: Sort if conf.orderByOrdinal && s.order.exists(o => isIntLiteral(o.child)) =>
val newOrders = s.order.map {
case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _, _) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index e1dd010d37a95..4772ab16691c5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -125,6 +125,13 @@ object TypeCoercion {
case (DateType, TimestampType) => Some(StringType)
case (StringType, NullType) => Some(StringType)
case (NullType, StringType) => Some(StringType)
+
+ // There is no proper decimal type we can pick,
+ // using double type is the best we can do.
+ // See SPARK-22469 for details.
+ case (n: DecimalType, s: StringType) => Some(DoubleType)
+ case (s: StringType, n: DecimalType) => Some(DoubleType)
+
case (l: StringType, r: AtomicType) if r != StringType => Some(r)
case (l: AtomicType, r: StringType) if (l != StringType) => Some(l)
case (l, r) => None
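
A small sketch of the observable effect of the DecimalType/StringType rule above (SPARK-22469); the query is illustrative:

    import org.apache.spark.sql.SparkSession

    object DecimalStringCoercionExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[1]").appName("coercion").getOrCreate()

        spark.sql("SELECT CAST(1 AS DECIMAL(38, 18)) AS d").createOrReplaceTempView("t")

        // With the new rule, both sides of the comparison should be cast to DOUBLE rather than
        // to some arbitrarily chosen decimal type.
        val df = spark.sql("SELECT * FROM t WHERE d = '1.0'")
        df.explain(true)  // expect cast(d as double) = cast('1.0' as double) in the analyzed plan
        df.show()

        spark.stop()
      }
    }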
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index 3f76f26dbe4ec..6ab4153bac70e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -267,7 +267,7 @@ object UnsupportedOperationChecker {
throwError("Limits are not supported on streaming DataFrames/Datasets")
case Sort(_, _, _) if !containsCompleteData(subPlan) =>
- throwError("Sorting is not supported on streaming DataFrames/Datasets, unless it is on" +
+ throwError("Sorting is not supported on streaming DataFrames/Datasets, unless it is on " +
"aggregated DataFrame/Dataset in Complete output mode")
case Sample(_, _, _, _, child) if child.isStreaming =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala
new file mode 100644
index 0000000000000..a27aa845bf0ae
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ListQuery, TimeZoneAwareExpression}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.DataType
+
+/**
+ * Replace [[TimeZoneAwareExpression]] without timezone id by its copy with session local
+ * time zone.
+ */
+case class ResolveTimeZone(conf: SQLConf) extends Rule[LogicalPlan] {
+ private val transformTimeZoneExprs: PartialFunction[Expression, Expression] = {
+ case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty =>
+ e.withTimeZone(conf.sessionLocalTimeZone)
+ // Casts could be added to the subquery plan by the TypeCoercion rule while coercing the
+ // types between the value expression and the list query expression of an IN expression.
+ // We need to run the subquery plan through ResolveTimeZone again to set up time zone
+ // information for time-zone-aware expressions.
+ case e: ListQuery => e.withNewPlan(apply(e.plan))
+ }
+
+ override def apply(plan: LogicalPlan): LogicalPlan =
+ plan.resolveExpressions(transformTimeZoneExprs)
+
+ def resolveTimeZones(e: Expression): Expression = e.transform(transformTimeZoneExprs)
+}
+
+/**
+ * Mix-in trait for constructing valid [[Cast]] expressions.
+ */
+trait CastSupport {
+ /**
+ * Configuration used to create a valid cast expression.
+ */
+ def conf: SQLConf
+
+ /**
+ * Create a Cast expression with the session local time zone.
+ */
+ def cast(child: Expression, dataType: DataType): Cast = {
+ Cast(child, dataType, Option(conf.sessionLocalTimeZone))
+ }
+}
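
A hypothetical rule sketch showing how the CastSupport trait above is meant to be mixed in; StringifyProjectOutput is invented for illustration and is not part of this patch:

    import org.apache.spark.sql.catalyst.analysis.CastSupport
    import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, NamedExpression}
    import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
    import org.apache.spark.sql.catalyst.rules.Rule
    import org.apache.spark.sql.internal.SQLConf
    import org.apache.spark.sql.types.StringType

    // Renders every attribute produced by a Project as a string. Because `cast` comes from
    // CastSupport, each Cast is built with conf.sessionLocalTimeZone already set, so no separate
    // ResolveTimeZone pass is needed for the casts introduced here.
    case class StringifyProjectOutput(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport {
      override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
        case p @ Project(projectList, _) =>
          val newList: Seq[NamedExpression] = projectList.map {
            case a: Attribute => Alias(cast(a, StringType), a.name)()
            case other => other
          }
          p.copy(projectList = newList)
      }
    }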
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala
index 3bd54c257d98d..ea46dd7282401 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala
@@ -47,7 +47,7 @@ import org.apache.spark.sql.internal.SQLConf
* This should be only done after the batch of Resolution, because the view attributes are not
* completely resolved during the batch of Resolution.
*/
-case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] {
+case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport {
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case v @ View(desc, output, child) if child.resolved && output != child.output =>
val resolver = conf.resolver
@@ -78,7 +78,7 @@ case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] {
throw new AnalysisException(s"Cannot up cast ${originAttr.sql} from " +
s"${originAttr.dataType.simpleString} to ${attr.simpleString} as it may truncate\n")
} else {
- Alias(Cast(originAttr, attr.dataType), attr.name)(exprId = attr.exprId,
+ Alias(cast(originAttr, attr.dataType), attr.name)(exprId = attr.exprId,
qualifier = attr.qualifier, explicitMetadata = Some(attr.metadata))
}
case (_, originAttr) => originAttr
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala
index 08a01e8601897..8db6f79e0f395 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.catalog
import org.apache.spark.sql.catalyst.analysis.{FunctionAlreadyExistsException, NoSuchDatabaseException, NoSuchFunctionException, NoSuchTableException}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.ListenerBus
/**
* Interface for the system catalog (of functions, partitions, tables, and databases).
@@ -30,7 +31,8 @@ import org.apache.spark.sql.types.StructType
*
* Implementations should throw [[NoSuchDatabaseException]] when databases don't exist.
*/
-abstract class ExternalCatalog {
+abstract class ExternalCatalog
+ extends ListenerBus[ExternalCatalogEventListener, ExternalCatalogEvent] {
import CatalogTypes.TablePartitionSpec
protected def requireDbExists(db: String): Unit = {
@@ -61,9 +63,22 @@ abstract class ExternalCatalog {
// Databases
// --------------------------------------------------------------------------
- def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit
+ final def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = {
+ val db = dbDefinition.name
+ postToAll(CreateDatabasePreEvent(db))
+ doCreateDatabase(dbDefinition, ignoreIfExists)
+ postToAll(CreateDatabaseEvent(db))
+ }
+
+ protected def doCreateDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit
+
+ final def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = {
+ postToAll(DropDatabasePreEvent(db))
+ doDropDatabase(db, ignoreIfNotExists, cascade)
+ postToAll(DropDatabaseEvent(db))
+ }
- def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit
+ protected def doDropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit
/**
* Alter a database whose name matches the one specified in `dbDefinition`,
@@ -88,11 +103,39 @@ abstract class ExternalCatalog {
// Tables
// --------------------------------------------------------------------------
- def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit
+ final def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = {
+ val db = tableDefinition.database
+ val name = tableDefinition.identifier.table
+ postToAll(CreateTablePreEvent(db, name))
+ doCreateTable(tableDefinition, ignoreIfExists)
+ postToAll(CreateTableEvent(db, name))
+ }
+
+ protected def doCreateTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit
+
+ final def dropTable(
+ db: String,
+ table: String,
+ ignoreIfNotExists: Boolean,
+ purge: Boolean): Unit = {
+ postToAll(DropTablePreEvent(db, table))
+ doDropTable(db, table, ignoreIfNotExists, purge)
+ postToAll(DropTableEvent(db, table))
+ }
+
+ protected def doDropTable(
+ db: String,
+ table: String,
+ ignoreIfNotExists: Boolean,
+ purge: Boolean): Unit
- def dropTable(db: String, table: String, ignoreIfNotExists: Boolean, purge: Boolean): Unit
+ final def renameTable(db: String, oldName: String, newName: String): Unit = {
+ postToAll(RenameTablePreEvent(db, oldName, newName))
+ doRenameTable(db, oldName, newName)
+ postToAll(RenameTableEvent(db, oldName, newName))
+ }
- def renameTable(db: String, oldName: String, newName: String): Unit
+ protected def doRenameTable(db: String, oldName: String, newName: String): Unit
/**
* Alter a table whose database and name match the ones specified in `tableDefinition`, assuming
@@ -105,22 +148,18 @@ abstract class ExternalCatalog {
def alterTable(tableDefinition: CatalogTable): Unit
/**
- * Alter the schema of a table identified by the provided database and table name. The new schema
- * should still contain the existing bucket columns and partition columns used by the table. This
- * method will also update any Spark SQL-related parameters stored as Hive table properties (such
- * as the schema itself).
+ * Alter the data schema of a table identified by the provided database and table name. The new
+ * data schema should not have conflicting column names with the existing partition columns, and
+ * should still contain all the existing data columns.
*
* @param db Database that table to alter schema for exists in
* @param table Name of table to alter schema for
- * @param schema Updated schema to be used for the table (must contain existing partition and
- * bucket columns)
+ * @param newDataSchema Updated data schema to be used for the table.
*/
- def alterTableSchema(db: String, table: String, schema: StructType): Unit
+ def alterTableDataSchema(db: String, table: String, newDataSchema: StructType): Unit
def getTable(db: String, table: String): CatalogTable
- def getTableOption(db: String, table: String): Option[CatalogTable]
-
def tableExists(db: String, table: String): Boolean
def listTables(db: String): Seq[String]
@@ -269,11 +308,30 @@ abstract class ExternalCatalog {
// Functions
// --------------------------------------------------------------------------
- def createFunction(db: String, funcDefinition: CatalogFunction): Unit
+ final def createFunction(db: String, funcDefinition: CatalogFunction): Unit = {
+ val name = funcDefinition.identifier.funcName
+ postToAll(CreateFunctionPreEvent(db, name))
+ doCreateFunction(db, funcDefinition)
+ postToAll(CreateFunctionEvent(db, name))
+ }
+
+ protected def doCreateFunction(db: String, funcDefinition: CatalogFunction): Unit
+
+ final def dropFunction(db: String, funcName: String): Unit = {
+ postToAll(DropFunctionPreEvent(db, funcName))
+ doDropFunction(db, funcName)
+ postToAll(DropFunctionEvent(db, funcName))
+ }
+
+ protected def doDropFunction(db: String, funcName: String): Unit
- def dropFunction(db: String, funcName: String): Unit
+ final def renameFunction(db: String, oldName: String, newName: String): Unit = {
+ postToAll(RenameFunctionPreEvent(db, oldName, newName))
+ doRenameFunction(db, oldName, newName)
+ postToAll(RenameFunctionEvent(db, oldName, newName))
+ }
- def renameFunction(db: String, oldName: String, newName: String): Unit
+ protected def doRenameFunction(db: String, oldName: String, newName: String): Unit
def getFunction(db: String, funcName: String): CatalogFunction
@@ -281,4 +339,9 @@ abstract class ExternalCatalog {
def listFunctions(db: String, pattern: String): Seq[String]
+ override protected def doPostEvent(
+ listener: ExternalCatalogEventListener,
+ event: ExternalCatalogEvent): Unit = {
+ listener.onEvent(event)
+ }
}
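
The template-method shape introduced above, reduced to a toy analogue (MiniCatalog is not a real Spark class) to make the pre-event / do-work / post-event ordering explicit:

    // The base class owns event publication; subclasses only implement the protected do* hook,
    // as InMemoryCatalog does later in this patch for the real ExternalCatalog.
    abstract class MiniCatalog {
      private val listeners = scala.collection.mutable.ArrayBuffer.empty[String => Unit]
      def addListener(l: String => Unit): Unit = listeners += l
      private def postToAll(event: String): Unit = listeners.foreach(_(event))

      final def createDatabase(db: String): Unit = {
        postToAll(s"CreateDatabasePreEvent($db)")
        doCreateDatabase(db)                      // if this throws, no post-event is fired
        postToAll(s"CreateDatabaseEvent($db)")
      }
      protected def doCreateDatabase(db: String): Unit
    }

    object MiniCatalogDemo extends App {
      val catalog = new MiniCatalog {
        override protected def doCreateDatabase(db: String): Unit = println(s"creating $db")
      }
      catalog.addListener(event => println(s"listener saw: $event"))
      catalog.createDatabase("sales")
      // Output order: the pre-event, then "creating sales", then the post-event.
    }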
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala
index 3ca9e6a8da5b5..50f32e81d997d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala
@@ -155,10 +155,22 @@ object ExternalCatalogUtils {
})
inputPartitions.filter { p =>
- boundPredicate(p.toRow(partitionSchema, defaultTimeZoneId))
+ boundPredicate.eval(p.toRow(partitionSchema, defaultTimeZoneId))
}
}
}
+
+ /**
+ * Returns true if `spec1` is a partial partition spec w.r.t. `spec2`, e.g. PARTITION (a=1) is a
+ * partial partition spec w.r.t. PARTITION (a=1,b=2).
+ */
+ def isPartialPartitionSpec(
+ spec1: TablePartitionSpec,
+ spec2: TablePartitionSpec): Boolean = {
+ spec1.forall {
+ case (partitionColumn, value) => spec2(partitionColumn) == value
+ }
+ }
}
object CatalogUtils {
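
A few spot checks for the isPartialPartitionSpec helper relocated above; the maps stand in for partition specs:

    import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.isPartialPartitionSpec

    object PartialSpecSpotChecks extends App {
      val full = Map("a" -> "1", "b" -> "2")

      // PARTITION (a=1) is a partial spec of PARTITION (a=1, b=2): every listed column matches.
      println(isPartialPartitionSpec(Map("a" -> "1"), full))            // true
      // A mismatching value is not a partial spec.
      println(isPartialPartitionSpec(Map("a" -> "2"), full))            // false
      // The empty spec is trivially partial with respect to anything.
      println(isPartialPartitionSpec(Map.empty[String, String], full))  // true
    }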
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
index 9ca1c71d1dcb1..f83e28f5d046f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
@@ -98,7 +98,7 @@ class InMemoryCatalog(
// Databases
// --------------------------------------------------------------------------
- override def createDatabase(
+ override protected def doCreateDatabase(
dbDefinition: CatalogDatabase,
ignoreIfExists: Boolean): Unit = synchronized {
if (catalog.contains(dbDefinition.name)) {
@@ -119,7 +119,7 @@ class InMemoryCatalog(
}
}
- override def dropDatabase(
+ override protected def doDropDatabase(
db: String,
ignoreIfNotExists: Boolean,
cascade: Boolean): Unit = synchronized {
@@ -180,7 +180,7 @@ class InMemoryCatalog(
// Tables
// --------------------------------------------------------------------------
- override def createTable(
+ override protected def doCreateTable(
tableDefinition: CatalogTable,
ignoreIfExists: Boolean): Unit = synchronized {
assert(tableDefinition.identifier.database.isDefined)
@@ -221,7 +221,7 @@ class InMemoryCatalog(
}
}
- override def dropTable(
+ override protected def doDropTable(
db: String,
table: String,
ignoreIfNotExists: Boolean,
@@ -264,7 +264,10 @@ class InMemoryCatalog(
}
}
- override def renameTable(db: String, oldName: String, newName: String): Unit = synchronized {
+ override protected def doRenameTable(
+ db: String,
+ oldName: String,
+ newName: String): Unit = synchronized {
requireTableExists(db, oldName)
requireTableNotExists(db, newName)
val oldDesc = catalog(db).tables(oldName)
@@ -298,13 +301,14 @@ class InMemoryCatalog(
catalog(db).tables(tableDefinition.identifier.table).table = tableDefinition
}
- override def alterTableSchema(
+ override def alterTableDataSchema(
db: String,
table: String,
- schema: StructType): Unit = synchronized {
+ newDataSchema: StructType): Unit = synchronized {
requireTableExists(db, table)
val origTable = catalog(db).tables(table).table
- catalog(db).tables(table).table = origTable.copy(schema = schema)
+ val newSchema = StructType(newDataSchema ++ origTable.partitionSchema)
+ catalog(db).tables(table).table = origTable.copy(schema = newSchema)
}
override def getTable(db: String, table: String): CatalogTable = synchronized {
@@ -312,10 +316,6 @@ class InMemoryCatalog(
catalog(db).tables(table).table
}
- override def getTableOption(db: String, table: String): Option[CatalogTable] = synchronized {
- if (!tableExists(db, table)) None else Option(catalog(db).tables(table).table)
- }
-
override def tableExists(db: String, table: String): Boolean = synchronized {
requireDbExists(db)
catalog(db).tables.contains(table)
@@ -539,18 +539,6 @@ class InMemoryCatalog(
}
}
- /**
- * Returns true if `spec1` is a partial partition spec w.r.t. `spec2`, e.g. PARTITION (a=1) is a
- * partial partition spec w.r.t. PARTITION (a=1,b=2).
- */
- private def isPartialPartitionSpec(
- spec1: TablePartitionSpec,
- spec2: TablePartitionSpec): Boolean = {
- spec1.forall {
- case (partitionColumn, value) => spec2(partitionColumn) == value
- }
- }
-
override def listPartitionsByFilter(
db: String,
table: String,
@@ -565,18 +553,21 @@ class InMemoryCatalog(
// Functions
// --------------------------------------------------------------------------
- override def createFunction(db: String, func: CatalogFunction): Unit = synchronized {
+ override protected def doCreateFunction(db: String, func: CatalogFunction): Unit = synchronized {
requireDbExists(db)
requireFunctionNotExists(db, func.identifier.funcName)
catalog(db).functions.put(func.identifier.funcName, func)
}
- override def dropFunction(db: String, funcName: String): Unit = synchronized {
+ override protected def doDropFunction(db: String, funcName: String): Unit = synchronized {
requireFunctionExists(db, funcName)
catalog(db).functions.remove(funcName)
}
- override def renameFunction(db: String, oldName: String, newName: String): Unit = synchronized {
+ override protected def doRenameFunction(
+ db: String,
+ oldName: String,
+ newName: String): Unit = synchronized {
requireFunctionExists(db, oldName)
requireFunctionNotExists(db, newName)
val newFunc = getFunction(db, oldName).copy(identifier = FunctionIdentifier(newName, Some(db)))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
index 3fbf83f3a38a2..bbcfdace731ee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.catalog
import java.net.URI
import java.util.Locale
+import java.util.concurrent.Callable
import javax.annotation.concurrent.GuardedBy
import scala.collection.mutable
@@ -73,7 +74,7 @@ class SessionCatalog(
functionRegistry,
conf,
new Configuration(),
- CatalystSqlParser,
+ new CatalystSqlParser(conf),
DummyFunctionResourceLoader)
}
@@ -115,24 +116,46 @@ class SessionCatalog(
* Format table name, taking into account case sensitivity.
*/
protected[this] def formatTableName(name: String): String = {
- if (conf.caseSensitiveAnalysis) name else name.toLowerCase
+ if (conf.caseSensitiveAnalysis) name else name.toLowerCase(Locale.ROOT)
}
/**
* Format database name, taking into account case sensitivity.
*/
protected[this] def formatDatabaseName(name: String): String = {
- if (conf.caseSensitiveAnalysis) name else name.toLowerCase
+ if (conf.caseSensitiveAnalysis) name else name.toLowerCase(Locale.ROOT)
}
- /**
- * A cache of qualified table names to table relation plans.
- */
- val tableRelationCache: Cache[QualifiedTableName, LogicalPlan] = {
+ private val tableRelationCache: Cache[QualifiedTableName, LogicalPlan] = {
val cacheSize = conf.tableRelationCacheSize
CacheBuilder.newBuilder().maximumSize(cacheSize).build[QualifiedTableName, LogicalPlan]()
}
+ /** This method provides a way to get a cached plan. */
+ def getCachedPlan(t: QualifiedTableName, c: Callable[LogicalPlan]): LogicalPlan = {
+ tableRelationCache.get(t, c)
+ }
+
+ /** This method provides a way to get a cached plan if the key exists. */
+ def getCachedTable(key: QualifiedTableName): LogicalPlan = {
+ tableRelationCache.getIfPresent(key)
+ }
+
+ /** This method provides a way to cache a plan. */
+ def cacheTable(t: QualifiedTableName, l: LogicalPlan): Unit = {
+ tableRelationCache.put(t, l)
+ }
+
+ /** This method provides a way to invalidate a cached plan. */
+ def invalidateCachedTable(key: QualifiedTableName): Unit = {
+ tableRelationCache.invalidate(key)
+ }
+
+ /** This method provides a way to invalidate all the cached plans. */
+ def invalidateAllCachedTables(): Unit = {
+ tableRelationCache.invalidateAll()
+ }
+
/**
* This method is used to make the given path qualified before we
* store this path in the underlying external catalog. So, when a path
@@ -313,30 +336,28 @@ class SessionCatalog(
}
/**
- * Alter the schema of a table identified by the provided table identifier. The new schema
- * should still contain the existing bucket columns and partition columns used by the table. This
- * method will also update any Spark SQL-related parameters stored as Hive table properties (such
- * as the schema itself).
+ * Alter the data schema of a table identified by the provided table identifier. The new data
+ * schema should not have conflicting column names with the existing partition columns, and should
+ * still contain all the existing data columns.
*
* @param identifier TableIdentifier
- * @param newSchema Updated schema to be used for the table (must contain existing partition and
- * bucket columns, and partition columns need to be at the end)
+ * @param newDataSchema Updated data schema to be used for the table
*/
- def alterTableSchema(
+ def alterTableDataSchema(
identifier: TableIdentifier,
- newSchema: StructType): Unit = {
+ newDataSchema: StructType): Unit = {
val db = formatDatabaseName(identifier.database.getOrElse(getCurrentDatabase))
val table = formatTableName(identifier.table)
val tableIdentifier = TableIdentifier(table, Some(db))
requireDbExists(db)
requireTableExists(tableIdentifier)
- checkDuplication(newSchema)
val catalogTable = externalCatalog.getTable(db, table)
- val oldSchema = catalogTable.schema
-
+ checkDuplication(newDataSchema ++ catalogTable.partitionSchema)
+ val oldDataSchema = catalogTable.dataSchema
// not supporting dropping columns yet
- val nonExistentColumnNames = oldSchema.map(_.name).filterNot(columnNameResolved(newSchema, _))
+ val nonExistentColumnNames =
+ oldDataSchema.map(_.name).filterNot(columnNameResolved(newDataSchema, _))
if (nonExistentColumnNames.nonEmpty) {
throw new AnalysisException(
s"""
@@ -345,8 +366,7 @@ class SessionCatalog(
""".stripMargin)
}
- // assuming the newSchema has all partition columns at the end as required
- externalCatalog.alterTableSchema(db, table, newSchema)
+ externalCatalog.alterTableDataSchema(db, table, newDataSchema)
}
private def columnNameResolved(schema: StructType, colName: String): Boolean = {
@@ -365,9 +385,10 @@ class SessionCatalog(
/**
* Retrieve the metadata of an existing permanent table/view. If no database is specified,
- * assume the table/view is in the current database. If the specified table/view is not found
- * in the database then a [[NoSuchTableException]] is thrown.
+ * assume the table/view is in the current database.
*/
+ @throws[NoSuchDatabaseException]
+ @throws[NoSuchTableException]
def getTableMetadata(name: TableIdentifier): CatalogTable = {
val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase))
val table = formatTableName(name.table)
@@ -376,18 +397,6 @@ class SessionCatalog(
externalCatalog.getTable(db, table)
}
- /**
- * Retrieve the metadata of an existing metastore table.
- * If no database is specified, assume the table is in the current database.
- * If the specified table is not found in the database then return None if it doesn't exist.
- */
- def getTableMetadataOption(name: TableIdentifier): Option[CatalogTable] = {
- val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase))
- val table = formatTableName(name.table)
- requireDbExists(db)
- externalCatalog.getTableOption(db, table)
- }
-
/**
* Load files stored in given path into an existing metastore table.
* If no database is specified, assume the table is in the current database.
@@ -667,12 +676,7 @@ class SessionCatalog(
child = parser.parsePlan(viewText))
SubqueryAlias(table, child)
} else {
- val tableRelation = CatalogRelation(
- metadata,
- // we assume all the columns are nullable.
- metadata.dataSchema.asNullable.toAttributes,
- metadata.partitionSchema.asNullable.toAttributes)
- SubqueryAlias(table, tableRelation)
+ SubqueryAlias(table, UnresolvedCatalogRelation(metadata))
}
} else {
SubqueryAlias(table, tempTables(table))
@@ -1105,8 +1109,9 @@ class SessionCatalog(
!hiveFunctions.contains(name.funcName.toLowerCase(Locale.ROOT))
}
- protected def failFunctionLookup(name: String): Nothing = {
- throw new NoSuchFunctionException(db = currentDb, func = name)
+ protected def failFunctionLookup(name: FunctionIdentifier): Nothing = {
+ throw new NoSuchFunctionException(
+ db = name.database.getOrElse(getCurrentDatabase), func = name.funcName)
}
/**
@@ -1128,7 +1133,7 @@ class SessionCatalog(
qualifiedName.database.orNull,
qualifiedName.identifier)
} else {
- failFunctionLookup(name.funcName)
+ failFunctionLookup(name)
}
}
}
@@ -1158,8 +1163,8 @@ class SessionCatalog(
}
// If the name itself is not qualified, add the current database to it.
- val database = name.database.orElse(Some(currentDb)).map(formatDatabaseName)
- val qualifiedName = name.copy(database = database)
+ val database = formatDatabaseName(name.database.getOrElse(getCurrentDatabase))
+ val qualifiedName = name.copy(database = Some(database))
if (functionRegistry.functionExists(qualifiedName.unquotedString)) {
// This function has been already loaded into the function registry.
@@ -1172,10 +1177,10 @@ class SessionCatalog(
// in the metastore). We need to first put the function in the FunctionRegistry.
// TODO: why not just check whether the function exists first?
val catalogFunction = try {
- externalCatalog.getFunction(currentDb, name.funcName)
+ externalCatalog.getFunction(database, name.funcName)
} catch {
- case e: AnalysisException => failFunctionLookup(name.funcName)
- case e: NoSuchPermanentFunctionException => failFunctionLookup(name.funcName)
+ case _: AnalysisException => failFunctionLookup(name)
+ case _: NoSuchPermanentFunctionException => failFunctionLookup(name)
}
loadFunctionResources(catalogFunction.resources)
// Please note that qualifiedName is provided by the user. However,
@@ -1251,9 +1256,10 @@ class SessionCatalog(
dropTempFunction(func.funcName, ignoreIfNotExists = false)
}
}
- tempTables.clear()
+ clearTempTables()
globalTempViewManager.clear()
functionRegistry.clear()
+ tableRelationCache.invalidateAll()
// restore built-in functions
FunctionRegistry.builtin.listFunction().foreach { f =>
val expressionInfo = FunctionRegistry.builtin.lookupFunction(f)
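
A sketch of how the new cache accessors are intended to be used from code that already holds a catalyst SessionCatalog; cachedOrElse and evict are made-up wrappers, and QualifiedTableName is assumed to be the catalyst identifier type keyed by the cache:

    import java.util.concurrent.Callable

    import org.apache.spark.sql.catalyst.QualifiedTableName
    import org.apache.spark.sql.catalyst.catalog.SessionCatalog
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

    object PlanCacheSketch {
      // Look up the cached plan for (db, table), computing and caching `fallback` on a miss.
      def cachedOrElse(catalog: SessionCatalog, db: String, table: String)
                      (fallback: => LogicalPlan): LogicalPlan = {
        val key = QualifiedTableName(db, table)
        catalog.getCachedPlan(key, new Callable[LogicalPlan] {
          override def call(): LogicalPlan = fallback   // only runs if the key is absent
        })
      }

      // Dropping or altering a table should evict its cached plan.
      def evict(catalog: SessionCatalog, db: String, table: String): Unit =
        catalog.invalidateCachedTable(QualifiedTableName(db, table))
    }

A caller inside Spark would pass the expensive relation construction as `fallback`; note that reset() above now also clears the whole cache via invalidateAll().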
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/events.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/events.scala
new file mode 100644
index 0000000000000..459973a13bb10
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/events.scala
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.catalog
+
+import org.apache.spark.scheduler.SparkListenerEvent
+
+/**
+ * Event emitted by the external catalog when it is modified. Events are either fired before or
+ * after the modification (the event should document this).
+ */
+trait ExternalCatalogEvent extends SparkListenerEvent
+
+/**
+ * Listener interface for external catalog modification events.
+ */
+trait ExternalCatalogEventListener {
+ def onEvent(event: ExternalCatalogEvent): Unit
+}
+
+/**
+ * Event fired when a database is created or dropped.
+ */
+trait DatabaseEvent extends ExternalCatalogEvent {
+ /**
+ * Database of the object that was touched.
+ */
+ val database: String
+}
+
+/**
+ * Event fired before a database is created.
+ */
+case class CreateDatabasePreEvent(database: String) extends DatabaseEvent
+
+/**
+ * Event fired after a database has been created.
+ */
+case class CreateDatabaseEvent(database: String) extends DatabaseEvent
+
+/**
+ * Event fired before a database is dropped.
+ */
+case class DropDatabasePreEvent(database: String) extends DatabaseEvent
+
+/**
+ * Event fired after a database has been dropped.
+ */
+case class DropDatabaseEvent(database: String) extends DatabaseEvent
+
+/**
+ * Event fired when a table is created, dropped or renamed.
+ */
+trait TableEvent extends DatabaseEvent {
+ /**
+ * Name of the table that was touched.
+ */
+ val name: String
+}
+
+/**
+ * Event fired before a table is created.
+ */
+case class CreateTablePreEvent(database: String, name: String) extends TableEvent
+
+/**
+ * Event fired after a table has been created.
+ */
+case class CreateTableEvent(database: String, name: String) extends TableEvent
+
+/**
+ * Event fired before a table is dropped.
+ */
+case class DropTablePreEvent(database: String, name: String) extends TableEvent
+
+/**
+ * Event fired after a table has been dropped.
+ */
+case class DropTableEvent(database: String, name: String) extends TableEvent
+
+/**
+ * Event fired before a table is renamed.
+ */
+case class RenameTablePreEvent(
+ database: String,
+ name: String,
+ newName: String)
+ extends TableEvent
+
+/**
+ * Event fired after a table has been renamed.
+ */
+case class RenameTableEvent(
+ database: String,
+ name: String,
+ newName: String)
+ extends TableEvent
+
+/**
+ * Event fired when a function is created, dropped or renamed.
+ */
+trait FunctionEvent extends DatabaseEvent {
+ /**
+ * Name of the function that was touched.
+ */
+ val name: String
+}
+
+/**
+ * Event fired before a function is created.
+ */
+case class CreateFunctionPreEvent(database: String, name: String) extends FunctionEvent
+
+/**
+ * Event fired after a function has been created.
+ */
+case class CreateFunctionEvent(database: String, name: String) extends FunctionEvent
+
+/**
+ * Event fired before a function is dropped.
+ */
+case class DropFunctionPreEvent(database: String, name: String) extends FunctionEvent
+
+/**
+ * Event fired after a function has been dropped.
+ */
+case class DropFunctionEvent(database: String, name: String) extends FunctionEvent
+
+/**
+ * Event fired before a function is renamed.
+ */
+case class RenameFunctionPreEvent(
+ database: String,
+ name: String,
+ newName: String)
+ extends FunctionEvent
+
+/**
+ * Event fired after a function has been renamed.
+ */
+case class RenameFunctionEvent(
+ database: String,
+ name: String,
+ newName: String)
+ extends FunctionEvent
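
A minimal listener sketch against the event interfaces defined above; the wiring lines are schematic and assume addListener, which ExternalCatalog now inherits from ListenerBus:

    import org.apache.spark.sql.catalyst.catalog._

    // Logs catalog modifications; pre-events arrive before the change, post-events after it.
    class LoggingCatalogListener extends ExternalCatalogEventListener {
      override def onEvent(event: ExternalCatalogEvent): Unit = event match {
        case CreateTablePreEvent(db, name) => println(s"about to create table $db.$name")
        case CreateTableEvent(db, name)    => println(s"created table $db.$name")
        case DropTableEvent(db, name)      => println(s"dropped table $db.$name")
        case e: DatabaseEvent              => println(s"other database-scoped event: $e")
        case other                         => println(s"other catalog event: $other")
      }
    }

    // Schematic wiring:
    //   externalCatalog.addListener(new LoggingCatalogListener)
    //   externalCatalog.createTable(tableDefinition, ignoreIfExists = false)
    //   // prints the pre-event line first and the post-event line once the table exists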
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
index cc0cbba275b81..5c8e5709a34f3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
@@ -75,7 +75,7 @@ case class CatalogStorageFormat(
CatalogUtils.maskCredentials(properties) match {
case props if props.isEmpty => // No-op
case props =>
- map.put("Properties", props.map(p => p._1 + "=" + p._2).mkString("[", ", ", "]"))
+ map.put("Storage Properties", props.map(p => p._1 + "=" + p._2).mkString("[", ", ", "]"))
}
map
}
@@ -313,7 +313,7 @@ case class CatalogTable(
}
}
- if (properties.nonEmpty) map.put("Properties", tableProperties)
+ if (properties.nonEmpty) map.put("Table Properties", tableProperties)
stats.foreach(s => map.put("Statistics", s.simpleString))
map ++= storage.toLinkedHashMap
if (tracksPartitionsInCatalog) map.put("Partition Provider", "Catalog")
@@ -397,11 +397,22 @@ object CatalogTypes {
type TablePartitionSpec = Map[String, String]
}
+/**
+ * A placeholder for a table relation, which will be replaced by concrete relation like
+ * `LogicalRelation` or `HiveTableRelation`, during analysis.
+ */
+case class UnresolvedCatalogRelation(tableMeta: CatalogTable) extends LeafNode {
+ assert(tableMeta.identifier.database.isDefined)
+ override lazy val resolved: Boolean = false
+ override def output: Seq[Attribute] = Nil
+}
/**
- * A [[LogicalPlan]] that represents a table.
+ * A `LogicalPlan` that represents a Hive table.
+ *
+ * TODO: remove this after we completely make Hive a data source.
*/
-case class CatalogRelation(
+case class HiveTableRelation(
tableMeta: CatalogTable,
dataCols: Seq[AttributeReference],
partitionCols: Seq[AttributeReference]) extends LeafNode with MultiInstanceRelation {
@@ -415,7 +426,7 @@ case class CatalogRelation(
def isPartitioned: Boolean = partitionCols.nonEmpty
override def equals(relation: Any): Boolean = relation match {
- case other: CatalogRelation => tableMeta == other.tableMeta && output == other.output
+ case other: HiveTableRelation => tableMeta == other.tableMeta && output == other.output
case _ => false
}
@@ -434,15 +445,12 @@ case class CatalogRelation(
))
override def computeStats(conf: SQLConf): Statistics = {
- // For data source tables, we will create a `LogicalRelation` and won't call this method, for
- // hive serde tables, we will always generate a statistics.
- // TODO: unify the table stats generation.
tableMeta.stats.map(_.toPlanStats(output)).getOrElse {
throw new IllegalStateException("table stats must be specified.")
}
}
- override def newInstance(): LogicalPlan = copy(
+ override def newInstance(): HiveTableRelation = copy(
dataCols = dataCols.map(_.newInstance()),
partitionCols = partitionCols.map(_.newInstance()))
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 75bf780d41424..85d17afe20230 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -168,6 +168,7 @@ package object dsl {
case Seq() => UnresolvedStar(None)
case target => UnresolvedStar(Option(target))
}
+ def namedStruct(e: Expression*): Expression = CreateNamedStruct(e)
def callFunction[T, U](
func: T => U,
@@ -366,7 +367,7 @@ package object dsl {
def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan =
InsertIntoTable(
analysis.UnresolvedRelation(TableIdentifier(tableName)),
- Map.empty, logicalPlan, overwrite, false)
+ Map.empty, logicalPlan, overwrite, ifPartitionNotExists = false)
def as(alias: String): LogicalPlan = SubqueryAlias(alias, logicalPlan)
@@ -381,6 +382,9 @@ package object dsl {
def analyze: LogicalPlan =
EliminateSubqueryAliases(analysis.SimpleAnalyzer.execute(logicalPlan))
+
+ def hint(name: String, parameters: Any*): LogicalPlan =
+ UnresolvedHint(name, parameters, logicalPlan)
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index bb1273f5c3d84..5380e6b0e8954 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -89,6 +89,31 @@ object Cast {
case _ => false
}
+ /**
+ * Return true if we need to use the `timeZone` information when casting the `from` type to the `to` type.
+ * The patterns matched reflect the current implementation in the Cast node.
+ * Cf. the usage of `timeZone` in:
+ * * Cast.castToString
+ * * Cast.castToDate
+ * * Cast.castToTimestamp
+ */
+ def needsTimeZone(from: DataType, to: DataType): Boolean = (from, to) match {
+ case (StringType, TimestampType) => true
+ case (DateType, TimestampType) => true
+ case (TimestampType, StringType) => true
+ case (TimestampType, DateType) => true
+ case (ArrayType(fromType, _), ArrayType(toType, _)) => needsTimeZone(fromType, toType)
+ case (MapType(fromKey, fromValue, _), MapType(toKey, toValue, _)) =>
+ needsTimeZone(fromKey, toKey) || needsTimeZone(fromValue, toValue)
+ case (StructType(fromFields), StructType(toFields)) =>
+ fromFields.length == toFields.length &&
+ fromFields.zip(toFields).exists {
+ case (fromField, toField) =>
+ needsTimeZone(fromField.dataType, toField.dataType)
+ }
+ case _ => false
+ }
+
/**
* Return true iff we may truncate during casting `from` type to `to` type. e.g. long -> int,
* timestamp -> date.
@@ -165,6 +190,13 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
copy(timeZoneId = Option(timeZoneId))
+ // When this cast involves a time zone, it is only resolved if the timeZoneId is set;
+ // otherwise it behaves like Expression.resolved.
+ override lazy val resolved: Boolean =
+ childrenResolved && checkInputDataTypes().isSuccess && (!needsTimeZone || timeZoneId.isDefined)
+
+ private[this] def needsTimeZone: Boolean = Cast.needsTimeZone(child.dataType, dataType)
+
// [[func]] assumes the input is no longer null because eval already does the null check.
@inline private[this] def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T])
@@ -355,10 +387,9 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
/**
* Create new `Decimal` with precision and scale given in `decimalType` (if any),
* returning null if it overflows or creating a new `value` and returning it if successful.
- *
*/
private[this] def toPrecision(value: Decimal, decimalType: DecimalType): Decimal =
- value.toPrecision(decimalType.precision, decimalType.scale).orNull
+ value.toPrecision(decimalType.precision, decimalType.scale)
private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match {
@@ -450,15 +481,15 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
case (fromField, toField) => cast(fromField.dataType, toField.dataType)
}
// TODO: Could be faster?
- val newRow = new GenericInternalRow(from.fields.length)
buildCast[InternalRow](_, row => {
+ val newRow = new GenericInternalRow(from.fields.length)
var i = 0
while (i < row.numFields) {
newRow.update(i,
if (row.isNullAt(i)) null else castFuncs(i)(row.get(i, from.apply(i).dataType)))
i += 1
}
- newRow.copy()
+ newRow
})
}
@@ -1008,13 +1039,21 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
}
}
"""
- }.mkString("\n")
+ }
+ val fieldsEvalCodes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) {
+ ctx.splitExpressions(
+ expressions = fieldsEvalCode,
+ funcName = "castStruct",
+ arguments = ("InternalRow", tmpRow) :: (rowClass, result) :: Nil)
+ } else {
+ fieldsEvalCode.mkString("\n")
+ }
(c, evPrim, evNull) =>
s"""
final $rowClass $result = new $rowClass(${fieldsCasts.length});
final InternalRow $tmpRow = $c;
- $fieldsEvalCode
+ $fieldsEvalCodes
$evPrim = $result.copy();
"""
}
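
Spot checks for Cast.needsTimeZone as added above; the expected values follow directly from the pattern match, including the recursion into complex types:

    import org.apache.spark.sql.catalyst.expressions.Cast
    import org.apache.spark.sql.types._

    object NeedsTimeZoneSpotChecks extends App {
      // Direct string/date <-> timestamp conversions consult the session time zone.
      println(Cast.needsTimeZone(StringType, TimestampType))   // true
      println(Cast.needsTimeZone(TimestampType, DateType))     // true
      // Purely numeric casts do not.
      println(Cast.needsTimeZone(IntegerType, LongType))       // false
      // The check recurses into array element types...
      println(Cast.needsTimeZone(ArrayType(StringType), ArrayType(TimestampType)))  // true
      // ...and into struct fields when both sides have the same arity.
      val from = StructType(Seq(StructField("ts", StringType)))
      val to = StructType(Seq(StructField("ts", TimestampType)))
      println(Cast.needsTimeZone(from, to))                    // true
    }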
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index b847ef7bfaa97..74c4cddf2b47e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -241,6 +241,10 @@ trait RuntimeReplaceable extends UnaryExpression with Unevaluable {
override def nullable: Boolean = child.nullable
override def foldable: Boolean = child.foldable
override def dataType: DataType = child.dataType
+ // As this expression gets replaced at optimization time with its `child` expression, two
+ // `RuntimeReplaceable` expressions are considered semantically equal if their `child`
+ // expressions are semantically equal.
+ override lazy val canonicalized: Expression = child.canonicalized
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index c423e17169e85..f80df75ac7f72 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.{DecimalPrecision, TypeCheckResult}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.TypeUtils
@@ -77,10 +77,10 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
// If all input are nulls, count will be 0 and we will get null after the division.
override lazy val evaluateExpression = child.dataType match {
- case DecimalType.Fixed(p, s) =>
- // increase the precision and scale to prevent precision loss
- val dt = DecimalType.bounded(p + 14, s + 4)
- Cast(Cast(sum, dt) / Cast(count, dt), resultType)
+ case _: DecimalType =>
+ Cast(
+ DecimalPrecision.decimalAndDecimal(sum / Cast(count, DecimalType.LongDecimal)),
+ resultType)
case _ =>
Cast(sum, resultType) / Cast(count, resultType)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 80c25d0b0fb7a..fffcc7c9ef53a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -105,12 +105,22 @@ case class AggregateExpression(
}
// We compute the same thing regardless of our final result.
- override lazy val canonicalized: Expression =
+ override lazy val canonicalized: Expression = {
+ val normalizedAggFunc = mode match {
+ // For PartialMerge or Final mode, the input to the `aggregateFunction` is aggregate buffers,
+ // and the actual children of `aggregateFunction` are not used, so we normalize the expr id here.
+ case PartialMerge | Final => aggregateFunction.transform {
+ case a: AttributeReference => a.withExprId(ExprId(0))
+ }
+ case Partial | Complete => aggregateFunction
+ }
+
AggregateExpression(
- aggregateFunction.canonicalized.asInstanceOf[AggregateFunction],
+ normalizedAggFunc.canonicalized.asInstanceOf[AggregateFunction],
mode,
isDistinct,
ExprId(0))
+ }
override def children: Seq[Expression] = aggregateFunction :: Nil
override def dataType: DataType = aggregateFunction.dataType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index f2b252259b89d..6059ca3079dd8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -550,8 +550,8 @@ case class Least(children: Seq[Expression]) extends Expression {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val evalChildren = children.map(_.genCode(ctx))
- val first = evalChildren(0)
- val rest = evalChildren.drop(1)
+ ctx.addMutableState("boolean", ev.isNull, "")
+ ctx.addMutableState(ctx.javaType(dataType), ev.value, "")
def updateEval(eval: ExprCode): String = {
s"""
${eval.code}
@@ -562,11 +562,11 @@ case class Least(children: Seq[Expression]) extends Expression {
}
"""
}
+ val codes = ctx.splitExpressions(ctx.INPUT_ROW, evalChildren.map(updateEval))
ev.copy(code = s"""
- ${first.code}
- boolean ${ev.isNull} = ${first.isNull};
- ${ctx.javaType(dataType)} ${ev.value} = ${first.value};
- ${rest.map(updateEval).mkString("\n")}""")
+ ${ev.isNull} = true;
+ ${ev.value} = ${ctx.defaultValue(dataType)};
+ $codes""")
}
}
@@ -615,8 +615,8 @@ case class Greatest(children: Seq[Expression]) extends Expression {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val evalChildren = children.map(_.genCode(ctx))
- val first = evalChildren(0)
- val rest = evalChildren.drop(1)
+ ctx.addMutableState("boolean", ev.isNull, "")
+ ctx.addMutableState(ctx.javaType(dataType), ev.value, "")
def updateEval(eval: ExprCode): String = {
s"""
${eval.code}
@@ -627,10 +627,10 @@ case class Greatest(children: Seq[Expression]) extends Expression {
}
"""
}
+ val codes = ctx.splitExpressions(ctx.INPUT_ROW, evalChildren.map(updateEval))
ev.copy(code = s"""
- ${first.code}
- boolean ${ev.isNull} = ${first.isNull};
- ${ctx.javaType(dataType)} ${ev.value} = ${first.value};
- ${rest.map(updateEval).mkString("\n")}""")
+ ${ev.isNull} = true;
+ ${ev.value} = ${ctx.defaultValue(dataType)};
+ $codes""")
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 760ead42c762c..9e5eaf6ff25b0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -27,7 +27,10 @@ import scala.language.existentials
import scala.util.control.NonFatal
import com.google.common.cache.{CacheBuilder, CacheLoader}
-import org.codehaus.janino.{ByteArrayClassLoader, ClassBodyEvaluator, SimpleCompiler}
+import com.google.common.util.concurrent.{ExecutionError, UncheckedExecutionException}
+import org.apache.commons.lang3.exception.ExceptionUtils
+import org.codehaus.commons.compiler.CompileException
+import org.codehaus.janino.{ByteArrayClassLoader, ClassBodyEvaluator, InternalCompilerException, SimpleCompiler}
import org.codehaus.janino.util.ClassFile
import org.apache.spark.{SparkEnv, TaskContext, TaskKilledException}
@@ -479,8 +482,10 @@ class CodegenContext {
*/
def genEqual(dataType: DataType, c1: String, c2: String): String = dataType match {
case BinaryType => s"java.util.Arrays.equals($c1, $c2)"
- case FloatType => s"(java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2"
- case DoubleType => s"(java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2"
+ case FloatType =>
+ s"((java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2)"
+ case DoubleType =>
+ s"((java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2)"
case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2"
case dt: DataType if dt.isInstanceOf[AtomicType] => s"$c1.equals($c2)"
case array: ArrayType => genComp(array, c1, c2) + " == 0"
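
Adding the outer parentheses in genEqual matters because the returned snippet is spliced verbatim into larger Java expressions; without them, a caller that prefixes `!` negates only the NaN clause. A small worked example (Scala has the same operator precedence as the generated Java here):

    object PrecedenceSketch extends App {
      val c1 = 1.0
      val c2 = 1.0
      // Intended "not equal": negate the whole NaN-aware equality.
      val intended = !((c1.isNaN && c2.isNaN) || c1 == c2)
      // Textually splicing "!" in front of the unparenthesized form negates only the first clause.
      val spliced = !(c1.isNaN && c2.isNaN) || c1 == c2
      println(intended) // false: 1.0 equals 1.0, so "not equal" is false
      println(spliced)  // true:  the negation bound only to the NaN check
    }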
@@ -657,20 +662,7 @@ class CodegenContext {
returnType: String = "void",
makeSplitFunction: String => String = identity,
foldFunctions: Seq[String] => String = _.mkString("", ";\n", ";")): String = {
- val blocks = new ArrayBuffer[String]()
- val blockBuilder = new StringBuilder()
- for (code <- expressions) {
- // We can't know how many bytecode will be generated, so use the length of source code
- // as metric. A method should not go beyond 8K, otherwise it will not be JITted, should
- // also not be too small, or it will have many function calls (for wide table), see the
- // results in BenchmarkWideTable.
- if (blockBuilder.length > 1024) {
- blocks += blockBuilder.toString()
- blockBuilder.clear()
- }
- blockBuilder.append(code)
- }
- blocks += blockBuilder.toString()
+ val blocks = buildCodeBlocks(expressions)
if (blocks.length == 1) {
// inline execution if only one block
@@ -693,6 +685,59 @@ class CodegenContext {
}
}
+ /**
+ * Splits the generated code of expressions into multiple String blocks,
+ * closing a block once its accumulated length exceeds a threshold.
+ *
+ * @param expressions the code snippets that evaluate the expressions.
+ */
+ def buildCodeBlocks(expressions: Seq[String]): Seq[String] = {
+ val blocks = new ArrayBuffer[String]()
+ val blockBuilder = new StringBuilder()
+ for (code <- expressions) {
+ // We can't know how much bytecode will be generated, so use the length of the source code
+ // as a metric. A method should not go beyond 8K bytes, otherwise it will not be JIT-compiled;
+ // it should also not be too small, or there will be many function calls (for wide tables). See
+ // the results in BenchmarkWideTable.
+ if (blockBuilder.length > 1024) {
+ blocks += blockBuilder.toString()
+ blockBuilder.clear()
+ }
+ blockBuilder.append(code)
+ }
+ blocks += blockBuilder.toString()
+ }
+
+ /**
+ * Wraps the generated code of an expression, which reads from the row object in INPUT_ROW,
+ * in a private function. ev.isNull and ev.value are passed back through global variables.
+ *
+ * @param ev the code that evaluates the expression.
+ * @param dataType the data type of ev.value.
+ * @param baseFuncName the split function name base.
+ */
+ def createAndAddFunction(
+ ev: ExprCode,
+ dataType: DataType,
+ baseFuncName: String): (String, String, String) = {
+ val globalIsNull = freshName("isNull")
+ addMutableState("boolean", globalIsNull, s"$globalIsNull = false;")
+ val globalValue = freshName("value")
+ addMutableState(javaType(dataType), globalValue,
+ s"$globalValue = ${defaultValue(dataType)};")
+ val funcName = freshName(baseFuncName)
+ val funcBody =
+ s"""
+ |private void $funcName(InternalRow ${INPUT_ROW}) {
+ | ${ev.code.trim}
+ | $globalIsNull = ${ev.isNull};
+ | $globalValue = ${ev.value};
+ |}
+ """.stripMargin
+ addNewFunction(funcName, funcBody)
+ (funcName, globalIsNull, globalValue)
+ }
+
/**
* Perform a function which generates a sequence of ExprCodes with a given mapping between
* expressions and common expressions, instead of using the mapping in current context.
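
The new buildCodeBlocks helper factors out the block-accumulation step that splitExpressions already used. A re-implementation of the same strategy for illustration only (the 1024-character threshold mirrors the comment above):

    import scala.collection.mutable.ArrayBuffer

    object BuildBlocksSketch extends App {
      def buildCodeBlocks(expressions: Seq[String], threshold: Int = 1024): Seq[String] = {
        val blocks = new ArrayBuffer[String]()
        val current = new StringBuilder()
        for (code <- expressions) {
          if (current.length > threshold) { // close the block once it grows past the threshold
            blocks += current.toString()
            current.clear()
          }
          current.append(code)
        }
        blocks += current.toString()
        blocks
      }

      val snippets = Seq.fill(10)("x = x + 1;\n" * 30) // ~330 characters each
      println(buildCodeBlocks(snippets).map(_.length)) // a few blocks just over 1024 chars, plus a shorter tail
    }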
@@ -899,8 +944,14 @@ object CodeGenerator extends Logging {
/**
* Compile the Java source code into a Java class, using Janino.
*/
- def compile(code: CodeAndComment): GeneratedClass = {
+ def compile(code: CodeAndComment): GeneratedClass = try {
cache.get(code)
+ } catch {
+ // Cache.get() may wrap the original exception. See the following URL
+ // http://google.github.io/guava/releases/14.0/api/docs/com/google/common/cache/
+ // Cache.html#get(K,%20java.util.concurrent.Callable)
+ case e @ (_: UncheckedExecutionException | _: ExecutionError) =>
+ throw e.getCause
}
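
The catch clause above exists because a Guava LoadingCache wraps exceptions thrown by its loader. A minimal sketch of that wrapping, and of unwrapping via getCause, using only Guava APIs:

    import com.google.common.cache.{CacheBuilder, CacheLoader}
    import com.google.common.util.concurrent.UncheckedExecutionException

    object CacheUnwrapSketch extends App {
      val cache = CacheBuilder.newBuilder().maximumSize(10).build(
        new CacheLoader[String, String] {
          // An unchecked exception thrown by the loader...
          override def load(key: String): String = throw new IllegalArgumentException(s"cannot load $key")
        })

      try {
        cache.get("k")
      } catch {
        // ...reaches the caller wrapped, so the original must be recovered with getCause.
        case e: UncheckedExecutionException => println(e.getCause)
      }
    }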
/**
@@ -951,10 +1002,14 @@ object CodeGenerator extends Logging {
evaluator.cook("generated.java", code.body)
recordCompilationStats(evaluator)
} catch {
- case e: Exception =>
+ case e: InternalCompilerException =>
+ val msg = s"failed to compile: $e\n$formatted"
+ logError(msg, e)
+ throw new InternalCompilerException(msg, e)
+ case e: CompileException =>
val msg = s"failed to compile: $e\n$formatted"
logError(msg, e)
- throw new Exception(msg, e)
+ throw new CompileException(msg, e.getLocation)
}
evaluator.getClazz().newInstance().asInstanceOf[GeneratedClass]
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
index f7fc2d54a047b..a2fe55bfef7a4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
@@ -72,13 +72,15 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
* Generates the code for ordering based on the given order.
*/
def genComparisons(ctx: CodegenContext, ordering: Seq[SortOrder]): String = {
+ val oldInputRow = ctx.INPUT_ROW
+ val oldCurrentVars = ctx.currentVars
+ val inputRow = "i"
+ ctx.INPUT_ROW = inputRow
+ // to use INPUT_ROW we must make sure currentVars is null
+ ctx.currentVars = null
+
val comparisons = ordering.map { order =>
- val oldCurrentVars = ctx.currentVars
- ctx.INPUT_ROW = "i"
- // to use INPUT_ROW we must make sure currentVars is null
- ctx.currentVars = null
val eval = order.child.genCode(ctx)
- ctx.currentVars = oldCurrentVars
val asc = order.isAscending
val isNullA = ctx.freshName("isNullA")
val primitiveA = ctx.freshName("primitiveA")
@@ -147,10 +149,12 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
"""
}.mkString
})
+ ctx.currentVars = oldCurrentVars
+ ctx.INPUT_ROW = oldInputRow
// make sure INPUT_ROW is declared even if splitExpressions
// returns an inlined block
s"""
- |InternalRow ${ctx.INPUT_ROW} = null;
+ |InternalRow $inputRow = null;
|$code
""".stripMargin
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 7e4c9089a2cb9..b358102d914bd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -50,10 +50,17 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
fieldTypes: Seq[DataType],
bufferHolder: String): String = {
val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) =>
- val fieldName = ctx.freshName("fieldName")
- val code = s"final ${ctx.javaType(dt)} $fieldName = ${ctx.getValue(input, dt, i.toString)};"
- val isNull = s"$input.isNullAt($i)"
- ExprCode(code, isNull, fieldName)
+ val javaType = ctx.javaType(dt)
+ val isNullVar = ctx.freshName("isNull")
+ val valueVar = ctx.freshName("value")
+ val defaultValue = ctx.defaultValue(dt)
+ val readValue = ctx.getValue(input, dt, i.toString)
+ val code =
+ s"""
+ boolean $isNullVar = $input.isNullAt($i);
+ $javaType $valueVar = $isNullVar ? $defaultValue : $readValue;
+ """
+ ExprCode(code, isNullVar, valueVar)
}
s"""
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
index 4aa5ec82471ec..21df42b2b423a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
@@ -17,6 +17,9 @@
package org.apache.spark.sql.catalyst.expressions.codegen
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.Platform
@@ -51,6 +54,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
}
def create(schema1: StructType, schema2: StructType): UnsafeRowJoiner = {
+ val ctx = new CodegenContext
val offset = Platform.BYTE_ARRAY_OFFSET
val getLong = "Platform.getLong"
val putLong = "Platform.putLong"
@@ -66,7 +70,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
// --------------------- copy bitset from row 1 and row 2 --------------------------- //
val copyBitset = Seq.tabulate(outputBitsetWords) { i =>
- val bits = if (bitset1Remainder > 0) {
+ val bits = if (bitset1Remainder > 0 && bitset2Words != 0) {
if (i < bitset1Words - 1) {
s"$getLong(obj1, offset1 + ${i * 8})"
} else if (i == bitset1Words - 1) {
@@ -88,8 +92,14 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
s"$getLong(obj2, offset2 + ${(i - bitset1Words) * 8})"
}
}
- s"$putLong(buf, ${offset + i * 8}, $bits);"
- }.mkString("\n")
+ s"$putLong(buf, ${offset + i * 8}, $bits);\n"
+ }
+
+ val copyBitsets = ctx.splitExpressions(
+ expressions = copyBitset,
+ funcName = "copyBitsetFunc",
+ arguments = ("java.lang.Object", "obj1") :: ("long", "offset1") ::
+ ("java.lang.Object", "obj2") :: ("long", "offset2") :: Nil)
// --------------------- copy fixed length portion from row 1 ----------------------- //
var cursor = offset + outputBitsetWords * 8
@@ -142,7 +152,9 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
} else {
// Number of bytes to increase for the offset. Note that since in UnsafeRow we store the
// offset in the upper 32 bit of the words, we can just shift the offset to the left by
- // 32 and increment that amount in place.
+ // 32 and increment that amount in place. However, we need to handle the important special
+ // case of a null field, in which case the offset should be zero and should not have a
+ // shift added to it.
val shift =
if (i < schema1.size) {
s"${(outputBitsetWords - bitset1Words + schema2.size) * 8}L"
@@ -150,11 +162,55 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
s"(${(outputBitsetWords - bitset2Words + schema1.size) * 8}L + numBytesVariableRow1)"
}
val cursor = offset + outputBitsetWords * 8 + i * 8
+ // UnsafeRow is a little underspecified, so in what follows we'll treat UnsafeRowWriter's
+ // output as a de-facto specification for the internal layout of data.
+ //
+ // Null-valued fields will always have a data offset of 0 because
+ // UnsafeRowWriter.setNullAt(ordinal) sets the null bit and stores 0 in the field's
+ // position in the fixed-length section of the row. As a result, we must NOT add
+ // `shift` to the offset for null fields.
+ //
+ // We could perform a null-check here by inspecting the null-tracking bitmap, but doing
+ // so could be expensive and will add significant bloat to the generated code. Instead,
+ // we'll rely on the invariant "stored offset == 0 for variable-length data type implies
+ // that the field's value is null."
+ //
+ // To establish that this invariant holds, we'll prove that a non-null field can never
+ // have a stored offset of 0. There are two cases to consider:
+ //
+ // 1. The non-null field's data is of non-zero length: reading this field's value
+ // must read data from the variable-length section of the row, so the stored offset
+ // will actually be used in address calculation and must be correct. The offsets
+ // count bytes from the start of the UnsafeRow so these offsets will always be
+ // non-zero because the storage of the offsets themselves takes up space at the
+ // start of the row.
+ // 2. The non-null field's data is of zero length (i.e. its data is empty). In this
+ // case, we have to worry about the possibility that an arbitrary offset value was
+ // stored because we never actually read any bytes using this offset and therefore
+ // would not crash if it was incorrect. The variable-sized data writing paths in
+ // UnsafeRowWriter unconditionally calls setOffsetAndSize(ordinal, numBytes) with
+ // no special handling for the case where `numBytes == 0`. Internally,
+ // setOffsetAndSize computes the offset without taking the size into account. Thus
+ // the stored offset is the same non-zero offset that would be used if the field's
+ // dataSize was non-zero (and in (1) above we've shown that case behaves as we
+ // expect).
+ //
+ // Thus it is safe to perform `existingOffset != 0` checks here in the place of
+ // more expensive null-bit checks.
s"""
- |$putLong(buf, $cursor, $getLong(buf, $cursor) + ($shift << 32));
+ |existingOffset = $getLong(buf, $cursor);
+ |if (existingOffset != 0) {
+ | $putLong(buf, $cursor, existingOffset + ($shift << 32));
+ |}
""".stripMargin
}
- }.mkString("\n")
+ }
+
+ val updateOffsets = ctx.splitExpressions(
+ expressions = updateOffset,
+ funcName = "copyBitsetFunc",
+ arguments = ("long", "numBytesVariableRow1") :: Nil,
+ makeSplitFunction = (s: String) => "long existingOffset;\n" + s)
// ------------------------ Finally, put everything together --------------------------- //
val codeBody = s"""
@@ -166,6 +222,8 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
| private byte[] buf = new byte[64];
| private UnsafeRow out = new UnsafeRow(${schema1.size + schema2.size});
|
+ | ${ctx.declareAddedFunctions()}
+ |
| public UnsafeRow join(UnsafeRow row1, UnsafeRow row2) {
| // row1: ${schema1.size} fields, $bitset1Words words in bitset
| // row2: ${schema2.size}, $bitset2Words words in bitset
@@ -180,12 +238,13 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
| final java.lang.Object obj2 = row2.getBaseObject();
| final long offset2 = row2.getBaseOffset();
|
- | $copyBitset
+ | $copyBitsets
| $copyFixedLengthRow1
| $copyFixedLengthRow2
| $copyVariableLengthRow1
| $copyVariableLengthRow2
- | $updateOffset
+ | long existingOffset;
+ | $updateOffsets
|
| out.pointTo(buf, sizeInBytes);
|
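
The offset-shifting logic above relies on variable-length fields packing (offset, size) into one long, with the offset in the upper 32 bits as the comment in this hunk states (placing the size in the lower half is an assumption consistent with that layout, not spelled out here). A small sketch of the packing and of the guarded shift:

    object OffsetWordSketch extends App {
      def pack(offset: Int, size: Int): Long = (offset.toLong << 32) | (size & 0xffffffffL)
      def offsetOf(word: Long): Int = (word >>> 32).toInt
      def sizeOf(word: Long): Int = word.toInt

      val word = pack(offset = 96, size = 12)
      val shiftBytes = 40L
      // Adding (shift << 32) moves only the offset; null fields store 0 and must stay 0.
      val shifted = if (word != 0) word + (shiftBytes << 32) else word
      println((offsetOf(shifted), sizeOf(shifted))) // (136,12)
    }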
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index ee365fe636614..0830b6a0fd0bd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -72,11 +72,11 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
(ctx.INPUT_ROW != null && ctx.currentVars == null)) {
val (condFuncName, condGlobalIsNull, condGlobalValue) =
- createAndAddFunction(ctx, condEval, predicate.dataType, "evalIfCondExpr")
+ ctx.createAndAddFunction(condEval, predicate.dataType, "evalIfCondExpr")
val (trueFuncName, trueGlobalIsNull, trueGlobalValue) =
- createAndAddFunction(ctx, trueEval, trueValue.dataType, "evalIfTrueExpr")
+ ctx.createAndAddFunction(trueEval, trueValue.dataType, "evalIfTrueExpr")
val (falseFuncName, falseGlobalIsNull, falseGlobalValue) =
- createAndAddFunction(ctx, falseEval, falseValue.dataType, "evalIfFalseExpr")
+ ctx.createAndAddFunction(falseEval, falseValue.dataType, "evalIfFalseExpr")
s"""
$condFuncName(${ctx.INPUT_ROW});
boolean ${ev.isNull} = false;
@@ -112,29 +112,6 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
ev.copy(code = generatedCode)
}
- private def createAndAddFunction(
- ctx: CodegenContext,
- ev: ExprCode,
- dataType: DataType,
- baseFuncName: String): (String, String, String) = {
- val globalIsNull = ctx.freshName("isNull")
- ctx.addMutableState("boolean", globalIsNull, s"$globalIsNull = false;")
- val globalValue = ctx.freshName("value")
- ctx.addMutableState(ctx.javaType(dataType), globalValue,
- s"$globalValue = ${ctx.defaultValue(dataType)};")
- val funcName = ctx.freshName(baseFuncName)
- val funcBody =
- s"""
- |private void $funcName(InternalRow ${ctx.INPUT_ROW}) {
- | ${ev.code.trim}
- | $globalIsNull = ${ev.isNull};
- | $globalValue = ${ev.value};
- |}
- """.stripMargin
- ctx.addNewFunction(funcName, funcBody)
- (funcName, globalIsNull, globalValue)
- }
-
override def toString: String = s"if ($predicate) $trueValue else $falseValue"
override def sql: String = s"(IF(${predicate.sql}, ${trueValue.sql}, ${falseValue.sql}))"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
index f8fe774823e5b..0ab72074b480c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
@@ -24,7 +24,6 @@ import java.util.{Calendar, TimeZone}
import scala.util.control.NonFatal
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
@@ -34,6 +33,9 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
* Common base class for time zone aware expressions.
*/
trait TimeZoneAwareExpression extends Expression {
+ /** The expression is only resolved when the time zone has been set. */
+ override lazy val resolved: Boolean =
+ childrenResolved && checkInputDataTypes().isSuccess && timeZoneId.isDefined
/** the timezone ID to be used to evaluate value. */
def timeZoneId: Option[String]
@@ -41,7 +43,7 @@ trait TimeZoneAwareExpression extends Expression {
/** Returns a copy of this expression with the specified timeZoneId. */
def withTimeZone(timeZoneId: String): TimeZoneAwareExpression
- @transient lazy val timeZone: TimeZone = TimeZone.getTimeZone(timeZoneId.get)
+ @transient lazy val timeZone: TimeZone = DateTimeUtils.getTimeZone(timeZoneId.get)
}
/**
@@ -400,13 +402,15 @@ case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCa
}
}
+// scalastyle:off line.size.limit
@ExpressionDescription(
- usage = "_FUNC_(date) - Returns the week of the year of the given date.",
+ usage = "_FUNC_(date) - Returns the week of the year of the given date. A week is considered to start on a Monday and week 1 is the first week with >3 days.",
extended = """
Examples:
> SELECT _FUNC_('2008-02-20');
8
""")
+// scalastyle:on line.size.limit
case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
@@ -414,7 +418,7 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa
override def dataType: DataType = IntegerType
@transient private lazy val c = {
- val c = Calendar.getInstance(TimeZone.getTimeZone("UTC"))
+ val c = Calendar.getInstance(DateTimeUtils.getTimeZone("UTC"))
c.setFirstDayOfWeek(Calendar.MONDAY)
c.setMinimalDaysInFirstWeek(4)
c
@@ -429,9 +433,10 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa
nullSafeCodeGen(ctx, ev, time => {
val cal = classOf[Calendar].getName
val c = ctx.freshName("cal")
+ val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
ctx.addMutableState(cal, c,
s"""
- $c = $cal.getInstance(java.util.TimeZone.getTimeZone("UTC"));
+ $c = $cal.getInstance($dtu.getTimeZone("UTC"));
$c.setFirstDayOfWeek($cal.MONDAY);
$c.setMinimalDaysInFirstWeek(4);
""")
@@ -952,8 +957,9 @@ case class FromUTCTimestamp(left: Expression, right: Expression)
val tzTerm = ctx.freshName("tz")
val utcTerm = ctx.freshName("utc")
val tzClass = classOf[TimeZone].getName
- ctx.addMutableState(tzClass, tzTerm, s"""$tzTerm = $tzClass.getTimeZone("$tz");""")
- ctx.addMutableState(tzClass, utcTerm, s"""$utcTerm = $tzClass.getTimeZone("UTC");""")
+ val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
+ ctx.addMutableState(tzClass, tzTerm, s"""$tzTerm = $dtu.getTimeZone("$tz");""")
+ ctx.addMutableState(tzClass, utcTerm, s"""$utcTerm = $dtu.getTimeZone("UTC");""")
val eval = left.genCode(ctx)
ev.copy(code = s"""
|${eval.code}
@@ -1123,8 +1129,9 @@ case class ToUTCTimestamp(left: Expression, right: Expression)
val tzTerm = ctx.freshName("tz")
val utcTerm = ctx.freshName("utc")
val tzClass = classOf[TimeZone].getName
- ctx.addMutableState(tzClass, tzTerm, s"""$tzTerm = $tzClass.getTimeZone("$tz");""")
- ctx.addMutableState(tzClass, utcTerm, s"""$utcTerm = $tzClass.getTimeZone("UTC");""")
+ val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
+ ctx.addMutableState(tzClass, tzTerm, s"""$tzTerm = $dtu.getTimeZone("$tz");""")
+ ctx.addMutableState(tzClass, utcTerm, s"""$utcTerm = $dtu.getTimeZone("UTC");""")
val eval = left.genCode(ctx)
ev.copy(code = s"""
|${eval.code}
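
The switch from TimeZone.getTimeZone to DateTimeUtils.getTimeZone, in both interpreted and generated code, presumably avoids repeating the relatively expensive JDK lookup for every row by going through a cached helper. The sketch below shows a cached lookup of the same shape; it is an illustration, not the actual DateTimeUtils implementation:

    import java.util.TimeZone
    import java.util.concurrent.ConcurrentHashMap
    import java.util.function.{Function => JFunction}

    object TimeZoneCacheSketch extends App {
      private val cache = new ConcurrentHashMap[String, TimeZone]()

      def getTimeZone(id: String): TimeZone =
        cache.computeIfAbsent(id, new JFunction[String, TimeZone] {
          override def apply(key: String): TimeZone = TimeZone.getTimeZone(key)
        })

      println(getTimeZone("UTC") eq getTimeZone("UTC")) // true: the same cached instance is reused
    }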
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
index c2211ae5d594b..752dea23e1f7a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
@@ -85,7 +85,7 @@ case class CheckOverflow(child: Expression, dataType: DecimalType) extends Unary
override def nullable: Boolean = true
override def nullSafeEval(input: Any): Any =
- input.asInstanceOf[Decimal].toPrecision(dataType.precision, dataType.scale).orNull
+ input.asInstanceOf[Decimal].toPrecision(dataType.precision, dataType.scale)
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, eval => {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index 2a5963d37f5e8..be61b52bd4d91 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -388,9 +388,10 @@ abstract class HashExpression[E] extends Expression {
input: String,
result: String,
fields: Array[StructField]): String = {
- fields.zipWithIndex.map { case (field, index) =>
+ val hashes = fields.zipWithIndex.map { case (field, index) =>
nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx)
- }.mkString("\n")
+ }
+ ctx.splitExpressions(input, hashes)
}
@tailrec
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
index df4d406b84d60..5ede335dcbc50 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
-import java.io.{ByteArrayOutputStream, CharArrayWriter, StringWriter}
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, CharArrayWriter, InputStreamReader, StringWriter}
import scala.util.parsing.combinator.RegexParsers
@@ -149,7 +149,9 @@ case class GetJsonObject(json: Expression, path: Expression)
if (parsed.isDefined) {
try {
- Utils.tryWithResource(jsonFactory.createParser(jsonStr.getBytes)) { parser =>
+ /* We know the bytes are UTF-8 encoded. Pass a Reader to avoid having Jackson
+ detect character encoding which could fail for some malformed strings */
+ Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, jsonStr)) { parser =>
val output = new ByteArrayOutputStream()
val matched = Utils.tryWithResource(
jsonFactory.createGenerator(output, JsonEncoding.UTF8)) { generator =>
@@ -393,8 +395,10 @@ case class JsonTuple(children: Seq[Expression])
}
try {
- Utils.tryWithResource(jsonFactory.createParser(json.getBytes)) {
- parser => parseRow(parser, input)
+ /* We know the bytes are UTF-8 encoded. Pass a Reader to avoid having Jackson
+ detect character encoding which could fail for some malformed strings */
+ Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, json)) { parser =>
+ parseRow(parser, input)
}
} catch {
case _: JsonProcessingException =>
@@ -602,7 +606,7 @@ case class JsonToStructs(
{"a":1,"b":2}
> SELECT _FUNC_(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy'));
{"time":"26/08/2015"}
- > SELECT _FUNC_(array(named_struct('a', 1, 'b', 2));
+ > SELECT _FUNC_(array(named_struct('a', 1, 'b', 2)));
[{"a":1,"b":2}]
""")
// scalastyle:on line.size.limit
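
The comments in the hunks above explain that handing Jackson a Reader with a known charset skips its byte-level encoding detection. A standalone sketch of that technique with plain Jackson (independent of CreateJacksonParser):

    import java.io.{ByteArrayInputStream, InputStreamReader}
    import java.nio.charset.StandardCharsets
    import com.fasterxml.jackson.core.JsonFactory

    object Utf8ParserSketch extends App {
      val factory = new JsonFactory()
      val bytes = """{"a": 1}""".getBytes(StandardCharsets.UTF_8)

      // Creating the parser from a Reader fixes the encoding up front,
      // so Jackson never tries to sniff it from the raw bytes.
      val reader = new InputStreamReader(new ByteArrayInputStream(bytes), StandardCharsets.UTF_8)
      val parser = factory.createParser(reader)
      try {
        while (parser.nextToken() != null) {} // consume all tokens
        println("parsed OK")
      } finally {
        parser.close()
      }
    }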
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index c4d47ab2084fd..42d668958d2ab 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -232,18 +232,20 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL"
}
override def inputTypes: Seq[AbstractDataType] =
- Seq(TypeCollection(DoubleType, DecimalType))
+ Seq(TypeCollection(DoubleType, DecimalType, LongType))
protected override def nullSafeEval(input: Any): Any = child.dataType match {
+ case LongType => input.asInstanceOf[Long]
case DoubleType => f(input.asInstanceOf[Double]).toLong
- case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].ceil
+ case DecimalType.Fixed(_, _) => input.asInstanceOf[Decimal].ceil
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
child.dataType match {
case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c")
- case DecimalType.Fixed(precision, scale) =>
+ case DecimalType.Fixed(_, _) =>
defineCodeGen(ctx, ev, c => s"$c.ceil()")
+ case LongType => defineCodeGen(ctx, ev, c => s"$c")
case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))")
}
}
@@ -281,7 +283,7 @@ case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH"
> SELECT _FUNC_('100', 2, 10);
4
> SELECT _FUNC_(-10, 16, -10);
- 16
+ -16
""")
case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expression)
extends TernaryExpression with ImplicitCastInputTypes {
@@ -347,18 +349,20 @@ case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLO
}
override def inputTypes: Seq[AbstractDataType] =
- Seq(TypeCollection(DoubleType, DecimalType))
+ Seq(TypeCollection(DoubleType, DecimalType, LongType))
protected override def nullSafeEval(input: Any): Any = child.dataType match {
+ case LongType => input.asInstanceOf[Long]
case DoubleType => f(input.asInstanceOf[Double]).toLong
- case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].floor
+ case DecimalType.Fixed(_, _) => input.asInstanceOf[Decimal].floor
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
child.dataType match {
case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c")
- case DecimalType.Fixed(precision, scale) =>
+ case DecimalType.Fixed(_, _) =>
defineCodeGen(ctx, ev, c => s"$c.floor()")
+ case LongType => defineCodeGen(ctx, ev, c => s"$c")
case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))")
}
}
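
With LongType accepted directly, ceil and floor now pass integral inputs through unchanged instead of routing them through another numeric type; one illustrative benefit is avoiding the precision loss a double conversion would cause for very large longs (this example is mine, not from the PR):

    import org.apache.spark.sql.SparkSession

    object CeilLongSketch extends App {
      val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()
      // 2^53 + 1 is not representable as a double, so a double round-trip would change it.
      spark.sql(
        "SELECT CEIL(CAST(9007199254740993 AS BIGINT)) AS c, FLOOR(CAST(9007199254740993 AS BIGINT)) AS f"
      ).show(truncate = false)
      spark.stop()
    }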
@@ -966,7 +970,7 @@ case class Logarithm(left: Expression, right: Expression)
*
* @param child expr to be round, all [[NumericType]] is allowed as Input
* @param scale new scale to be round to, this should be a constant int at runtime
- * @param mode rounding mode (e.g. HALF_UP, HALF_UP)
+ * @param mode rounding mode (e.g. HALF_UP, HALF_EVEN)
* @param modeStr rounding mode string name (e.g. "ROUND_HALF_UP", "ROUND_HALF_EVEN")
*/
abstract class RoundBase(child: Expression, scale: Expression,
@@ -1023,10 +1027,10 @@ abstract class RoundBase(child: Expression, scale: Expression,
// not overriding since _scale is a constant int at runtime
def nullSafeEval(input1: Any): Any = {
- child.dataType match {
- case _: DecimalType =>
+ dataType match {
+ case DecimalType.Fixed(_, s) =>
val decimal = input1.asInstanceOf[Decimal]
- decimal.toPrecision(decimal.precision, _scale, mode).orNull
+ decimal.toPrecision(decimal.precision, s, mode)
case ByteType =>
BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte
case ShortType =>
@@ -1055,15 +1059,11 @@ abstract class RoundBase(child: Expression, scale: Expression,
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val ce = child.genCode(ctx)
- val evaluationCode = child.dataType match {
- case _: DecimalType =>
+ val evaluationCode = dataType match {
+ case DecimalType.Fixed(_, s) =>
s"""
- if (${ce.value}.changePrecision(${ce.value}.precision(), ${_scale},
- java.math.BigDecimal.${modeStr})) {
- ${ev.value} = ${ce.value};
- } else {
- ${ev.isNull} = true;
- }"""
+ ${ev.value} = ${ce.value}.toPrecision(${ce.value}.precision(), $s, Decimal.$modeStr());
+ ${ev.isNull} = ${ev.value} == null;"""
case ByteType =>
if (_scale < 0) {
s"""
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
index 92036b727dbbd..00bb1a154739f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
@@ -71,14 +71,10 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val first = children(0)
- val rest = children.drop(1)
- val firstEval = first.genCode(ctx)
- ev.copy(code = s"""
- ${firstEval.code}
- boolean ${ev.isNull} = ${firstEval.isNull};
- ${ctx.javaType(dataType)} ${ev.value} = ${firstEval.value};""" +
- rest.map { e =>
+ ctx.addMutableState("boolean", ev.isNull, "")
+ ctx.addMutableState(ctx.javaType(dataType), ev.value, "")
+
+ val evals = children.map { e =>
val eval = e.genCode(ctx)
s"""
if (${ev.isNull}) {
@@ -89,7 +85,12 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
}
}
"""
- }.mkString("\n"))
+ }
+
+ ev.copy(code = s"""
+ ${ev.isNull} = true;
+ ${ev.value} = ${ctx.defaultValue(dataType)};
+ ${ctx.splitExpressions(ctx.INPUT_ROW, evals)}""")
}
}
@@ -356,7 +357,7 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val nonnull = ctx.freshName("nonnull")
- val code = children.map { e =>
+ val evals = children.map { e =>
val eval = e.genCode(ctx)
e.dataType match {
case DoubleType | FloatType =>
@@ -378,7 +379,26 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate
}
"""
}
- }.mkString("\n")
+ }
+
+ val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) {
+ evals.mkString("\n")
+ } else {
+ ctx.splitExpressions(evals, "atLeastNNonNulls",
+ ("InternalRow", ctx.INPUT_ROW) :: ("int", nonnull) :: Nil,
+ returnType = "int",
+ makeSplitFunction = { body =>
+ s"""
+ $body
+ return $nonnull;
+ """
+ },
+ foldFunctions = { funcCalls =>
+ funcCalls.map(funcCall => s"$nonnull = $funcCall;").mkString("\n")
+ }
+ )
+ }
+
ev.copy(code = s"""
int $nonnull = 0;
$code
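
The AtLeastNNonNulls change threads the running counter through the split functions: each generated helper takes $nonnull as an argument, returns the updated value, and the calls are folded back into assignments. The same accumulator-passing shape in plain Scala (names illustrative):

    object SplitAccumulatorSketch extends App {
      // Each "split function" receives the running counter for its slice and returns the new value,
      // so the caller can chain them exactly like the generated "$nonnull = $funcCall;" lines.
      def countSlice(values: Seq[Option[Int]], nonnull: Int): Int =
        nonnull + values.count(_.isDefined)

      val values: Seq[Option[Int]] = Seq(Some(1), None, Some(3), None, Some(5), Some(6))
      var nonnull = 0
      for (slice <- values.grouped(2)) {
        nonnull = countSlice(slice, nonnull)
      }
      println(nonnull) // 4
    }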
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index f446c3e4a75f6..c5237660f0c7e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -27,6 +27,7 @@ import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.serializer._
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.ScalaReflection.universe.newTermName
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
@@ -189,11 +190,13 @@ case class Invoke(
override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
+ private lazy val encodedFunctionName = newTermName(functionName).encodedName.toString
+
@transient lazy val method = targetObject.dataType match {
case ObjectType(cls) =>
- val m = cls.getMethods.find(_.getName == functionName)
+ val m = cls.getMethods.find(_.getName == encodedFunctionName)
if (m.isEmpty) {
- sys.error(s"Couldn't find $functionName on $cls")
+ sys.error(s"Couldn't find $encodedFunctionName on $cls")
} else {
m
}
@@ -222,7 +225,7 @@ case class Invoke(
}
val evaluate = if (returnPrimitive) {
- getFuncResult(ev.value, s"${obj.value}.$functionName($argString)")
+ getFuncResult(ev.value, s"${obj.value}.$encodedFunctionName($argString)")
} else {
val funcResult = ctx.freshName("funcResult")
// If the function can return null, we do an extra check to make sure our null bit is still
@@ -240,7 +243,7 @@ case class Invoke(
}
s"""
Object $funcResult = null;
- ${getFuncResult(funcResult, s"${obj.value}.$functionName($argString)")}
+ ${getFuncResult(funcResult, s"${obj.value}.$encodedFunctionName($argString)")}
$assignResult
"""
}
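
Invoke now resolves methods by their encoded JVM name, so targets whose Scala names are not valid Java identifiers (operators, setters) can still be found via reflection. A quick demonstration of the encoding (TermName is the current spelling of the newTermName used above):

    import scala.reflect.runtime.universe.TermName

    object EncodedNameSketch extends App {
      println(TermName("+").encodedName)        // $plus
      println(TermName("value_=").encodedName)  // value_$eq
      println(TermName("contains").encodedName) // contains (already a valid JVM identifier)
    }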
@@ -451,6 +454,8 @@ object MapObjects {
* @param function The function applied on the collection elements.
* @param inputData An expression that when evaluated returns a collection object.
* @param elementType The data type of elements in the collection.
+ * @param elementNullable When false, indicates that elements in the collection are always
+ * non-null values.
* @param customCollectionCls Class of the resulting collection (returning ObjectType)
* or None (returning ArrayType)
*/
@@ -458,11 +463,12 @@ object MapObjects {
function: Expression => Expression,
inputData: Expression,
elementType: DataType,
+ elementNullable: Boolean = true,
customCollectionCls: Option[Class[_]] = None): MapObjects = {
val id = curId.getAndIncrement()
val loopValue = s"MapObjects_loopValue$id"
val loopIsNull = s"MapObjects_loopIsNull$id"
- val loopVar = LambdaVariable(loopValue, loopIsNull, elementType)
+ val loopVar = LambdaVariable(loopValue, loopIsNull, elementType, elementNullable)
MapObjects(
loopValue, loopIsNull, elementType, function(loopVar), inputData, customCollectionCls)
}
@@ -656,18 +662,21 @@ object ExternalMapToCatalyst {
inputMap: Expression,
keyType: DataType,
keyConverter: Expression => Expression,
+ keyNullable: Boolean,
valueType: DataType,
valueConverter: Expression => Expression,
valueNullable: Boolean): ExternalMapToCatalyst = {
val id = curId.getAndIncrement()
val keyName = "ExternalMapToCatalyst_key" + id
+ val keyIsNull = "ExternalMapToCatalyst_key_isNull" + id
val valueName = "ExternalMapToCatalyst_value" + id
val valueIsNull = "ExternalMapToCatalyst_value_isNull" + id
ExternalMapToCatalyst(
keyName,
+ keyIsNull,
keyType,
- keyConverter(LambdaVariable(keyName, "false", keyType, false)),
+ keyConverter(LambdaVariable(keyName, keyIsNull, keyType, keyNullable)),
valueName,
valueIsNull,
valueType,
@@ -683,6 +692,8 @@ object ExternalMapToCatalyst {
*
* @param key the name of the map key variable that used when iterate the map, and used as input for
* the `keyConverter`
+ * @param keyIsNull the nullability of the map key variable that is used when iterating over the
+ * map, and used as input for the `keyConverter`
* @param keyType the data type of the map key variable that used when iterate the map, and used as
* input for the `keyConverter`
* @param keyConverter A function that take the `key` as input, and converts it to catalyst format.
@@ -698,6 +709,7 @@ object ExternalMapToCatalyst {
*/
case class ExternalMapToCatalyst private(
key: String,
+ keyIsNull: String,
keyType: DataType,
keyConverter: Expression,
value: String,
@@ -726,6 +738,13 @@ case class ExternalMapToCatalyst private(
val entry = ctx.freshName("entry")
val entries = ctx.freshName("entries")
+ val keyElementJavaType = ctx.javaType(keyType)
+ val valueElementJavaType = ctx.javaType(valueType)
+ ctx.addMutableState("boolean", keyIsNull, "")
+ ctx.addMutableState(keyElementJavaType, key, "")
+ ctx.addMutableState("boolean", valueIsNull, "")
+ ctx.addMutableState(valueElementJavaType, value, "")
+
val (defineEntries, defineKeyValue) = child.dataType match {
case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) =>
val javaIteratorCls = classOf[java.util.Iterator[_]].getName
@@ -737,8 +756,8 @@ case class ExternalMapToCatalyst private(
val defineKeyValue =
s"""
final $javaMapEntryCls $entry = ($javaMapEntryCls) $entries.next();
- ${ctx.javaType(keyType)} $key = (${ctx.boxedType(keyType)}) $entry.getKey();
- ${ctx.javaType(valueType)} $value = (${ctx.boxedType(valueType)}) $entry.getValue();
+ $key = (${ctx.boxedType(keyType)}) $entry.getKey();
+ $value = (${ctx.boxedType(valueType)}) $entry.getValue();
"""
defineEntries -> defineKeyValue
@@ -752,17 +771,23 @@ case class ExternalMapToCatalyst private(
val defineKeyValue =
s"""
final $scalaMapEntryCls $entry = ($scalaMapEntryCls) $entries.next();
- ${ctx.javaType(keyType)} $key = (${ctx.boxedType(keyType)}) $entry._1();
- ${ctx.javaType(valueType)} $value = (${ctx.boxedType(valueType)}) $entry._2();
+ $key = (${ctx.boxedType(keyType)}) $entry._1();
+ $value = (${ctx.boxedType(valueType)}) $entry._2();
"""
defineEntries -> defineKeyValue
}
+ val keyNullCheck = if (ctx.isPrimitiveType(keyType)) {
+ s"$keyIsNull = false;"
+ } else {
+ s"$keyIsNull = $key == null;"
+ }
+
val valueNullCheck = if (ctx.isPrimitiveType(valueType)) {
- s"boolean $valueIsNull = false;"
+ s"$valueIsNull = false;"
} else {
- s"boolean $valueIsNull = $value == null;"
+ s"$valueIsNull = $value == null;"
}
val arrayCls = classOf[GenericArrayData].getName
@@ -781,6 +806,7 @@ case class ExternalMapToCatalyst private(
$defineEntries
while($entries.hasNext()) {
$defineKeyValue
+ $keyNullCheck
$valueNullCheck
${genKeyConverter.code}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 5034566132f7a..02fb262ec845e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -17,23 +17,26 @@
package org.apache.spark.sql.catalyst.expressions
+import scala.collection.immutable.TreeSet
+
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
object InterpretedPredicate {
- def create(expression: Expression, inputSchema: Seq[Attribute]): (InternalRow => Boolean) =
+ def create(expression: Expression, inputSchema: Seq[Attribute]): InterpretedPredicate =
create(BindReferences.bindReference(expression, inputSchema))
- def create(expression: Expression): (InternalRow => Boolean) = {
- (r: InternalRow) => expression.eval(r).asInstanceOf[Boolean]
- }
+ def create(expression: Expression): InterpretedPredicate = new InterpretedPredicate(expression)
}
+case class InterpretedPredicate(expression: Expression) extends BasePredicate {
+ override def eval(r: InternalRow): Boolean = expression.eval(r).asInstanceOf[Boolean]
+}
/**
* An [[Expression]] that returns a boolean value.
@@ -162,19 +165,22 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
|[${sub.output.map(_.dataType.catalogString).mkString(", ")}].
""".stripMargin)
} else {
- TypeCheckResult.TypeCheckSuccess
+ TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
}
case _ =>
- if (list.exists(l => l.dataType != value.dataType)) {
- TypeCheckResult.TypeCheckFailure("Arguments must be same type")
+ val mismatchOpt = list.find(l => l.dataType != value.dataType)
+ if (mismatchOpt.isDefined) {
+ TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " +
+ s"${value.dataType} != ${mismatchOpt.get.dataType}")
} else {
- TypeCheckResult.TypeCheckSuccess
+ TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
}
}
}
override def children: Seq[Expression] = value +: list
lazy val inSetConvertible = list.forall(_.isInstanceOf[Literal])
+ private lazy val ordering = TypeUtils.getInterpretedOrdering(value.dataType)
override def nullable: Boolean = children.exists(_.nullable)
override def foldable: Boolean = children.forall(_.foldable)
@@ -189,10 +195,10 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
var hasNull = false
list.foreach { e =>
val v = e.eval(input)
- if (v == evaluatedValue) {
- return true
- } else if (v == null) {
+ if (v == null) {
hasNull = true
+ } else if (ordering.equiv(v, evaluatedValue)) {
+ return true
}
}
if (hasNull) {
@@ -206,24 +212,34 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val valueGen = value.genCode(ctx)
val listGen = list.map(_.genCode(ctx))
+ ctx.addMutableState("boolean", ev.value, "")
+ ctx.addMutableState("boolean", ev.isNull, "")
+ val valueArg = ctx.freshName("valueArg")
val listCode = listGen.map(x =>
s"""
if (!${ev.value}) {
${x.code}
if (${x.isNull}) {
${ev.isNull} = true;
- } else if (${ctx.genEqual(value.dataType, valueGen.value, x.value)}) {
+ } else if (${ctx.genEqual(value.dataType, valueArg, x.value)}) {
${ev.isNull} = false;
${ev.value} = true;
}
}
- """).mkString("\n")
+ """)
+ val listCodes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) {
+ val args = ("InternalRow", ctx.INPUT_ROW) :: (ctx.javaType(value.dataType), valueArg) :: Nil
+ ctx.splitExpressions(listCode, "valueIn", args)
+ } else {
+ listCode.mkString("\n")
+ }
ev.copy(code = s"""
${valueGen.code}
- boolean ${ev.value} = false;
- boolean ${ev.isNull} = ${valueGen.isNull};
+ ${ev.value} = false;
+ ${ev.isNull} = ${valueGen.isNull};
if (!${ev.isNull}) {
- $listCode
+ ${ctx.javaType(value.dataType)} $valueArg = ${valueGen.value};
+ $listCodes
}
""")
}
@@ -251,7 +267,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
override def nullable: Boolean = child.nullable || hasNull
protected override def nullSafeEval(value: Any): Any = {
- if (hset.contains(value)) {
+ if (set.contains(value)) {
true
} else if (hasNull) {
null
@@ -260,27 +276,40 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
}
}
- def getHSet(): Set[Any] = hset
+ @transient private[this] lazy val set = child.dataType match {
+ case _: AtomicType => hset
+ case _: NullType => hset
+ case _ =>
+ // for structs use interpreted ordering to be able to compare UnsafeRows with non-UnsafeRows
+ TreeSet.empty(TypeUtils.getInterpretedOrdering(child.dataType)) ++ hset
+ }
+
+ def getSet(): Set[Any] = set
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val setName = classOf[Set[Any]].getName
val InSetName = classOf[InSet].getName
val childGen = child.genCode(ctx)
ctx.references += this
- val hsetTerm = ctx.freshName("hset")
- val hasNullTerm = ctx.freshName("hasNull")
- ctx.addMutableState(setName, hsetTerm,
- s"$hsetTerm = (($InSetName)references[${ctx.references.size - 1}]).getHSet();")
- ctx.addMutableState("boolean", hasNullTerm, s"$hasNullTerm = $hsetTerm.contains(null);")
+ val setTerm = ctx.freshName("set")
+ val setNull = if (hasNull) {
+ s"""
+ |if (!${ev.value}) {
+ | ${ev.isNull} = true;
+ |}
+ """.stripMargin
+ } else {
+ ""
+ }
+ ctx.addMutableState(setName, setTerm,
+ s"$setTerm = (($InSetName)references[${ctx.references.size - 1}]).getSet();")
ev.copy(code = s"""
${childGen.code}
boolean ${ev.isNull} = ${childGen.isNull};
boolean ${ev.value} = false;
if (!${ev.isNull}) {
- ${ev.value} = $hsetTerm.contains(${childGen.value});
- if (!${ev.value} && $hasNullTerm) {
- ${ev.isNull} = true;
- }
+ ${ev.value} = $setTerm.contains(${childGen.value});
+ $setNull
}
""")
}
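
For struct-typed children, InSet now performs membership checks through a TreeSet built with an interpreted ordering, so values that are logically equal but physically different (UnsafeRow vs. non-UnsafeRow) still match. The same contrast with plain collections, using illustrative stand-in classes rather than Spark's rows:

    import scala.collection.immutable.TreeSet

    object OrderedSetSketch extends App {
      sealed trait Row { def values: Seq[Int] }
      case class Compact(values: Seq[Int]) extends Row   // stand-in for one physical format
      case class Generic(values: Seq[Int]) extends Row   // stand-in for another

      // Orders by logical contents only, like an interpreted ordering over the data type.
      implicit val byContents: Ordering[Row] = Ordering.by((r: Row) => r.values.mkString(","))

      val hset: Set[Row] = Set(Compact(Seq(1, 2)))
      val set = TreeSet.empty[Row] ++ hset

      println(hset.contains(Generic(Seq(1, 2)))) // false: hash set relies on equals/hashCode
      println(set.contains(Generic(Seq(1, 2))))  // true:  tree set relies on the ordering
    }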
@@ -325,7 +354,46 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with
val eval2 = right.genCode(ctx)
// The result should be `false`, if any of them is `false` whenever the other is null or not.
- if (!left.nullable && !right.nullable) {
+
+ // Place the generated code of eval1 and eval2 in separate methods if their combined code is large.
+ val combinedLength = eval1.code.length + eval2.code.length
+ if (combinedLength > 1024 &&
+ // Split these expressions only if they are created from a row object
+ (ctx.INPUT_ROW != null && ctx.currentVars == null)) {
+
+ val (eval1FuncName, eval1GlobalIsNull, eval1GlobalValue) =
+ ctx.createAndAddFunction(eval1, BooleanType, "eval1Expr")
+ val (eval2FuncName, eval2GlobalIsNull, eval2GlobalValue) =
+ ctx.createAndAddFunction(eval2, BooleanType, "eval2Expr")
+ if (!left.nullable && !right.nullable) {
+ val generatedCode = s"""
+ $eval1FuncName(${ctx.INPUT_ROW});
+ boolean ${ev.value} = false;
+ if (${eval1GlobalValue}) {
+ $eval2FuncName(${ctx.INPUT_ROW});
+ ${ev.value} = ${eval2GlobalValue};
+ }
+ """
+ ev.copy(code = generatedCode, isNull = "false")
+ } else {
+ val generatedCode = s"""
+ $eval1FuncName(${ctx.INPUT_ROW});
+ boolean ${ev.isNull} = false;
+ boolean ${ev.value} = false;
+ if (!${eval1GlobalIsNull} && !${eval1GlobalValue}) {
+ } else {
+ $eval2FuncName(${ctx.INPUT_ROW});
+ if (!${eval2GlobalIsNull} && !${eval2GlobalValue}) {
+ } else if (!${eval1GlobalIsNull} && !${eval2GlobalIsNull}) {
+ ${ev.value} = true;
+ } else {
+ ${ev.isNull} = true;
+ }
+ }
+ """
+ ev.copy(code = generatedCode)
+ }
+ } else if (!left.nullable && !right.nullable) {
ev.copy(code = s"""
${eval1.code}
boolean ${ev.value} = false;
@@ -388,7 +456,46 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
val eval2 = right.genCode(ctx)
// The result should be `true`, if any of them is `true` whenever the other is null or not.
- if (!left.nullable && !right.nullable) {
+
+ // Place the generated code of eval1 and eval2 in separate methods if their combined code is large.
+ val combinedLength = eval1.code.length + eval2.code.length
+ if (combinedLength > 1024 &&
+ // Split these expressions only if they are created from a row object
+ (ctx.INPUT_ROW != null && ctx.currentVars == null)) {
+
+ val (eval1FuncName, eval1GlobalIsNull, eval1GlobalValue) =
+ ctx.createAndAddFunction(eval1, BooleanType, "eval1Expr")
+ val (eval2FuncName, eval2GlobalIsNull, eval2GlobalValue) =
+ ctx.createAndAddFunction(eval2, BooleanType, "eval2Expr")
+ if (!left.nullable && !right.nullable) {
+ val generatedCode = s"""
+ $eval1FuncName(${ctx.INPUT_ROW});
+ boolean ${ev.value} = true;
+ if (!${eval1GlobalValue}) {
+ $eval2FuncName(${ctx.INPUT_ROW});
+ ${ev.value} = ${eval2GlobalValue};
+ }
+ """
+ ev.copy(code = generatedCode, isNull = "false")
+ } else {
+ val generatedCode = s"""
+ $eval1FuncName(${ctx.INPUT_ROW});
+ boolean ${ev.isNull} = false;
+ boolean ${ev.value} = true;
+ if (!${eval1GlobalIsNull} && ${eval1GlobalValue}) {
+ } else {
+ $eval2FuncName(${ctx.INPUT_ROW});
+ if (!${eval2GlobalIsNull} && ${eval2GlobalValue}) {
+ } else if (!${eval1GlobalIsNull} && !${eval2GlobalIsNull}) {
+ ${ev.value} = false;
+ } else {
+ ${ev.isNull} = true;
+ }
+ }
+ """
+ ev.copy(code = generatedCode)
+ }
+ } else if (!left.nullable && !right.nullable) {
ev.isNull = "false"
ev.copy(code = s"""
${eval1.code}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
index 3fa84589e3c68..aa5a1b5448c6d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -86,6 +86,13 @@ abstract class StringRegexExpression extends BinaryExpression
escape character, the following character is matched literally. It is invalid to escape
any other character.
+ Since Spark 2.0, string literals are unescaped in our SQL parser. For example, in order
+ to match "\abc", the pattern should be "\\abc".
+
+ When the SQL config 'spark.sql.parser.escapedStringLiterals' is enabled, it falls back
+ to the Spark 1.6 behavior regarding string literal parsing. For example, if the config is
+ enabled, the pattern to match "\abc" should be "\abc".
+
Examples:
> SELECT '%SystemDrive%\Users\John' _FUNC_ '\%SystemDrive\%\\Users%'
true
@@ -144,7 +151,31 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi
}
@ExpressionDescription(
- usage = "str _FUNC_ regexp - Returns true if `str` matches `regexp`, or false otherwise.")
+ usage = "str _FUNC_ regexp - Returns true if `str` matches `regexp`, or false otherwise.",
+ extended = """
+ Arguments:
+ str - a string expression
+ regexp - a string expression. The pattern string should be a Java regular expression.
+
+ Since Spark 2.0, string literals (including regex patterns) are unescaped in our SQL parser.
+ For example, to match "\abc", a regular expression for `regexp` can be "^\\abc$".
+
+ There is a SQL config 'spark.sql.parser.escapedStringLiterals' that can be used to fall back
+ to the Spark 1.6 behavior regarding string literal parsing. For example, if the config is
+ enabled, the `regexp` that can match "\abc" is "^\abc$".
+
+ Examples:
+ When spark.sql.parser.escapedStringLiterals is disabled (default).
+ > SELECT '%SystemDrive%\Users\John' _FUNC_ '%SystemDrive%\\Users.*'
+ true
+
+ When spark.sql.parser.escapedStringLiterals is enabled.
+ > SELECT '%SystemDrive%\Users\John' _FUNC_ '%SystemDrive%\Users.*'
+ true
+
+ See also:
+ Use LIKE to match with a simple string pattern.
+""")
case class RLike(left: Expression, right: Expression) extends StringRegexExpression {
override def escape(v: String): String = v
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 5598a146997ca..767b59c03c32d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -62,15 +62,27 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val evals = children.map(_.genCode(ctx))
- val inputs = evals.map { eval =>
- s"${eval.isNull} ? null : ${eval.value}"
- }.mkString(", ")
- ev.copy(evals.map(_.code).mkString("\n") + s"""
- boolean ${ev.isNull} = false;
- UTF8String ${ev.value} = UTF8String.concat($inputs);
- if (${ev.value} == null) {
- ${ev.isNull} = true;
- }
+ val args = ctx.freshName("args")
+
+ val inputs = evals.zipWithIndex.map { case (eval, index) =>
+ s"""
+ ${eval.code}
+ if (!${eval.isNull}) {
+ $args[$index] = ${eval.value};
+ }
+ """
+ }
+ val codes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) {
+ ctx.splitExpressions(inputs, "valueConcat",
+ ("InternalRow", ctx.INPUT_ROW) :: ("UTF8String[]", args) :: Nil)
+ } else {
+ inputs.mkString("\n")
+ }
+ ev.copy(s"""
+ UTF8String[] $args = new UTF8String[${evals.length}];
+ $codes
+ UTF8String ${ev.value} = UTF8String.concat($args);
+ boolean ${ev.isNull} = ${ev.value} == null;
""")
}
}
@@ -124,13 +136,34 @@ case class ConcatWs(children: Seq[Expression])
if (children.forall(_.dataType == StringType)) {
// All children are strings. In that case we can construct a fixed size array.
val evals = children.map(_.genCode(ctx))
-
- val inputs = evals.map { eval =>
- s"${eval.isNull} ? (UTF8String) null : ${eval.value}"
- }.mkString(", ")
-
- ev.copy(evals.map(_.code).mkString("\n") + s"""
- UTF8String ${ev.value} = UTF8String.concatWs($inputs);
+ val separator = evals.head
+ val strings = evals.tail
+ val numArgs = strings.length
+ val args = ctx.freshName("args")
+
+ val inputs = strings.zipWithIndex.map { case (eval, index) =>
+ if (eval.isNull != "true") {
+ s"""
+ ${eval.code}
+ if (!${eval.isNull}) {
+ $args[$index] = ${eval.value};
+ }
+ """
+ } else {
+ ""
+ }
+ }
+ val codes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) {
+ ctx.splitExpressions(inputs, "valueConcatWs",
+ ("InternalRow", ctx.INPUT_ROW) :: ("UTF8String[]", args) :: Nil)
+ } else {
+ inputs.mkString("\n")
+ }
+ ev.copy(s"""
+ UTF8String[] $args = new UTF8String[$numArgs];
+ ${separator.code}
+ $codes
+ UTF8String ${ev.value} = UTF8String.concatWs(${separator.value}, $args);
boolean ${ev.isNull} = ${ev.value} == null;
""")
} else {
@@ -143,32 +176,63 @@ case class ConcatWs(children: Seq[Expression])
child.dataType match {
case StringType =>
("", // we count all the StringType arguments num at once below.
- s"$array[$idxInVararg ++] = ${eval.isNull} ? (UTF8String) null : ${eval.value};")
+ if (eval.isNull == "true") {
+ ""
+ } else {
+ s"$array[$idxInVararg ++] = ${eval.isNull} ? (UTF8String) null : ${eval.value};"
+ })
case _: ArrayType =>
val size = ctx.freshName("n")
- (s"""
- if (!${eval.isNull}) {
- $varargNum += ${eval.value}.numElements();
- }
- """,
- s"""
- if (!${eval.isNull}) {
- final int $size = ${eval.value}.numElements();
- for (int j = 0; j < $size; j ++) {
- $array[$idxInVararg ++] = ${ctx.getValue(eval.value, StringType, "j")};
- }
+ if (eval.isNull == "true") {
+ ("", "")
+ } else {
+ (s"""
+ if (!${eval.isNull}) {
+ $varargNum += ${eval.value}.numElements();
+ }
+ """,
+ s"""
+ if (!${eval.isNull}) {
+ final int $size = ${eval.value}.numElements();
+ for (int j = 0; j < $size; j ++) {
+ $array[$idxInVararg ++] = ${ctx.getValue(eval.value, StringType, "j")};
+ }
+ }
+ """)
}
- """)
}
}.unzip
- ev.copy(evals.map(_.code).mkString("\n") +
- s"""
+ val codes = ctx.splitExpressions(ctx.INPUT_ROW, evals.map(_.code))
+ val varargCounts = ctx.splitExpressions(varargCount, "varargCountsConcatWs",
+ ("InternalRow", ctx.INPUT_ROW) :: Nil,
+ "int",
+ { body =>
+ s"""
+ int $varargNum = 0;
+ $body
+ return $varargNum;
+ """
+ },
+ _.mkString(s"$varargNum += ", s";\n$varargNum += ", ";"))
+ val varargBuilds = ctx.splitExpressions(varargBuild, "varargBuildsConcatWs",
+ ("InternalRow", ctx.INPUT_ROW) :: ("UTF8String []", array) :: ("int", idxInVararg) :: Nil,
+ "int",
+ { body =>
+ s"""
+ $body
+ return $idxInVararg;
+ """
+ },
+ _.mkString(s"$idxInVararg = ", s";\n$idxInVararg = ", ";"))
+ ev.copy(
+ s"""
+ $codes
int $varargNum = ${children.count(_.dataType == StringType) - 1};
int $idxInVararg = 0;
- ${varargCount.mkString("\n")}
+ $varargCounts
UTF8String[] $array = new UTF8String[$varargNum];
- ${varargBuild.mkString("\n")}
+ $varargBuilds
UTF8String ${ev.value} = UTF8String.concatWs(${evals.head.value}, $array);
boolean ${ev.isNull} = ${ev.value} == null;
""")
@@ -223,22 +287,52 @@ case class Elt(children: Seq[Expression])
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val index = indexExpr.genCode(ctx)
val strings = stringExprs.map(_.genCode(ctx))
+ val indexVal = ctx.freshName("index")
+ val stringVal = ctx.freshName("stringVal")
val assignStringValue = strings.zipWithIndex.map { case (eval, index) =>
s"""
case ${index + 1}:
- ${ev.value} = ${eval.isNull} ? null : ${eval.value};
+ ${eval.code}
+ $stringVal = ${eval.isNull} ? null : ${eval.value};
break;
"""
- }.mkString("\n")
- val indexVal = ctx.freshName("index")
- val stringArray = ctx.freshName("strings");
+ }
- ev.copy(index.code + "\n" + strings.map(_.code).mkString("\n") + s"""
- final int $indexVal = ${index.value};
- UTF8String ${ev.value} = null;
- switch ($indexVal) {
- $assignStringValue
+ val cases = ctx.buildCodeBlocks(assignStringValue)
+ val codes = if (cases.length == 1) {
+ s"""
+ UTF8String $stringVal = null;
+ switch ($indexVal) {
+ ${cases.head}
+ }
+ """
+ } else {
+ var prevFunc = "null"
+ for (c <- cases.reverse) {
+ val funcName = ctx.freshName("eltFunc")
+ val funcBody = s"""
+ private UTF8String $funcName(InternalRow ${ctx.INPUT_ROW}, int $indexVal) {
+ UTF8String $stringVal = null;
+ switch ($indexVal) {
+ $c
+ default:
+ return $prevFunc;
+ }
+ return $stringVal;
+ }
+ """
+ ctx.addNewFunction(funcName, funcBody)
+ prevFunc = s"$funcName(${ctx.INPUT_ROW}, $indexVal)"
}
+ s"UTF8String $stringVal = $prevFunc;"
+ }
+
+ ev.copy(
+ s"""
+ ${index.code}
+ final int $indexVal = ${index.value};
+ $codes
+ UTF8String ${ev.value} = $stringVal;
final boolean ${ev.isNull} = ${ev.value} == null;
""")
}
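
The Elt change uses a second splitting trick: the large switch over the string arguments is carved into several generated functions, each covering one slice of case labels and falling back in its default branch to the function built for the later slices, so lookup walks the chain only as far as needed. A toy re-statement of that chaining in plain Scala (the slice size of 2 is illustrative, not what codegen uses):

    // Toy sketch of Elt's chained dispatch: each slice handles its own 1-based indices and
    // delegates unknown indices to the function covering the later slices.
    object EltChainSketch {
      def buildChain(values: IndexedSeq[String], sliceSize: Int = 2): Int => Option[String] = {
        val slices = values.zipWithIndex.grouped(sliceSize).toSeq
        slices.reverse.foldLeft((_: Int) => Option.empty[String]) { (prev, slice) =>
          idx => slice.collectFirst { case (v, i) if i + 1 == idx => v }.orElse(prev(idx))
        }
      }

      def main(args: Array[String]): Unit = {
        val elt = buildChain(IndexedSeq("a", "b", "c", "d", "e"))
        println(elt(1)) // Some(a)
        println(elt(5)) // Some(e)
        println(elt(9)) // None, mirroring elt() returning null for an out-of-range index
      }
    }
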
@@ -960,10 +1054,10 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC
val pattern = children.head.genCode(ctx)
val argListGen = children.tail.map(x => (x.dataType, x.genCode(ctx)))
- val argListCode = argListGen.map(_._2.code + "\n")
-
- val argListString = argListGen.foldLeft("")((s, v) => {
- val nullSafeString =
+ val argList = ctx.freshName("argLists")
+ val numArgLists = argListGen.length
+ val argListCode = argListGen.zipWithIndex.map { case(v, index) =>
+ val value =
if (ctx.boxedType(v._1) != ctx.javaType(v._1)) {
// Java primitives get boxed in order to allow null values.
s"(${v._2.isNull}) ? (${ctx.boxedType(v._1)}) null : " +
@@ -971,8 +1065,19 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC
} else {
s"(${v._2.isNull}) ? null : ${v._2.value}"
}
- s + "," + nullSafeString
- })
+ s"""
+ ${v._2.code}
+ $argList[$index] = $value;
+ """
+ }
+ val argListCodes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) {
+ ctx.splitExpressions(
+ expressions = argListCode,
+ funcName = "valueFormatString",
+ arguments = ("InternalRow", ctx.INPUT_ROW) :: ("Object[]", argList) :: Nil)
+ } else {
+ argListCode.mkString("\n")
+ }
val form = ctx.freshName("formatter")
val formatter = classOf[java.util.Formatter].getName
@@ -983,10 +1088,11 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC
boolean ${ev.isNull} = ${pattern.isNull};
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
- ${argListCode.mkString}
$stringBuffer $sb = new $stringBuffer();
$formatter $form = new $formatter($sb, ${classOf[Locale].getName}.US);
- $form.format(${pattern.value}.toString() $argListString);
+ Object[] $argList = new Object[$numArgLists];
+ $argListCodes
+ $form.format(${pattern.value}.toString(), $argList);
${ev.value} = UTF8String.fromString($sb.toString());
}""")
}
@@ -1005,7 +1111,7 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC
""",
extended = """
Examples:
- > SELECT initcap('sPark sql');
+ > SELECT _FUNC_('sPark sql');
Spark Sql
""")
case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala
index e0ed03a68981a..025a388aacaa5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.json
-import java.io.InputStream
+import java.io.{ByteArrayInputStream, InputStream, InputStreamReader}
import com.fasterxml.jackson.core.{JsonFactory, JsonParser}
import org.apache.hadoop.io.Text
@@ -33,7 +33,10 @@ private[sql] object CreateJacksonParser extends Serializable {
val bb = record.getByteBuffer
assert(bb.hasArray)
- jsonFactory.createParser(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
+ val bain = new ByteArrayInputStream(
+ bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
+
+ jsonFactory.createParser(new InputStreamReader(bain, "UTF-8"))
}
def text(jsonFactory: JsonFactory, record: Text): JsonParser = {
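
Here the parser is now built from a Reader that decodes the byte range as UTF-8 explicitly, rather than handing Jackson the raw array slice. A minimal standalone version of that construction, using only Jackson's public JsonFactory API and standard JDK streams (the sample bytes are illustrative):

    import java.io.{ByteArrayInputStream, InputStreamReader}
    import java.nio.charset.StandardCharsets

    import com.fasterxml.jackson.core.JsonFactory

    object Utf8ParserSketch {
      def main(args: Array[String]): Unit = {
        val bytes = """{"name":"café"}""".getBytes(StandardCharsets.UTF_8)
        val in = new ByteArrayInputStream(bytes, 0, bytes.length)
        // Decode explicitly as UTF-8 before handing the stream to Jackson.
        val parser = new JsonFactory().createParser(new InputStreamReader(in, "UTF-8"))
        while (parser.nextToken() != null) {} // drive the parser to the end of input
        parser.close()
      }
    }
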
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
index 23ba5ed4d50dc..1fd680ab64b5a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
@@ -70,7 +70,7 @@ private[sql] class JSONOptions(
val columnNameOfCorruptRecord =
parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord)
- val timeZone: TimeZone = TimeZone.getTimeZone(
+ val timeZone: TimeZone = DateTimeUtils.getTimeZone(
parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId))
// Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe.
@@ -81,7 +81,7 @@ private[sql] class JSONOptions(
FastDateFormat.getInstance(
parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, Locale.US)
- val wholeFile = parameters.get("wholeFile").map(_.toBoolean).getOrElse(false)
+ val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false)
/** Sets config options on a Jackson [[JsonFactory]]. */
def setJacksonOptions(factory: JsonFactory): Unit = {
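
The data source option is renamed from wholeFile to multiLine. A hedged usage sketch (the path is hypothetical):

    import org.apache.spark.sql.SparkSession

    object MultiLineJsonExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("multiLine").getOrCreate()
        // Each input file may contain one JSON document spanning many lines.
        val df = spark.read
          .option("multiLine", "true")   // formerly "wholeFile"
          .json("/tmp/people.json")      // hypothetical path
        df.printSchema()
        spark.stop()
      }
    }
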
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
index ff6c93ae9815c..4ed6728994193 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
@@ -18,7 +18,6 @@
package org.apache.spark.sql.catalyst.json
import java.io.ByteArrayOutputStream
-import java.util.Locale
import scala.collection.mutable.ArrayBuffer
import scala.util.Try
@@ -126,16 +125,11 @@ class JacksonParser(
case VALUE_STRING =>
// Special case handling for NaN and Infinity.
- val value = parser.getText
- val lowerCaseValue = value.toLowerCase(Locale.ROOT)
- if (lowerCaseValue.equals("nan") ||
- lowerCaseValue.equals("infinity") ||
- lowerCaseValue.equals("-infinity") ||
- lowerCaseValue.equals("inf") ||
- lowerCaseValue.equals("-inf")) {
- value.toFloat
- } else {
- throw new RuntimeException(s"Cannot parse $value as FloatType.")
+ parser.getText match {
+ case "NaN" => Float.NaN
+ case "Infinity" => Float.PositiveInfinity
+ case "-Infinity" => Float.NegativeInfinity
+ case other => throw new RuntimeException(s"Cannot parse $other as FloatType.")
}
}
@@ -146,16 +140,11 @@ class JacksonParser(
case VALUE_STRING =>
// Special case handling for NaN and Infinity.
- val value = parser.getText
- val lowerCaseValue = value.toLowerCase(Locale.ROOT)
- if (lowerCaseValue.equals("nan") ||
- lowerCaseValue.equals("infinity") ||
- lowerCaseValue.equals("-infinity") ||
- lowerCaseValue.equals("inf") ||
- lowerCaseValue.equals("-inf")) {
- value.toDouble
- } else {
- throw new RuntimeException(s"Cannot parse $value as DoubleType.")
+ parser.getText match {
+ case "NaN" => Double.NaN
+ case "Infinity" => Double.PositiveInfinity
+ case "-Infinity" => Double.NegativeInfinity
+ case other => throw new RuntimeException(s"Cannot parse $other as DoubleType.")
}
}
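
After this change only the exact, case-sensitive JSON string tokens "NaN", "Infinity" and "-Infinity" map to the floating-point special values; lower-cased spellings and "Inf"/"-Inf" now raise the parse error instead of being silently accepted. A standalone restatement of the DoubleType branch:

    // Mirrors the new case-sensitive mapping for DoubleType (a sketch, not JacksonParser itself).
    object SpecialDoubleSketch {
      def parseSpecialDouble(text: String): Double = text match {
        case "NaN"       => Double.NaN
        case "Infinity"  => Double.PositiveInfinity
        case "-Infinity" => Double.NegativeInfinity
        case other       => throw new RuntimeException(s"Cannot parse $other as DoubleType.")
      }

      def main(args: Array[String]): Unit = {
        println(parseSpecialDouble("NaN"))       // NaN
        println(parseSpecialDouble("-Infinity")) // -Infinity
        // parseSpecialDouble("inf") now throws, unlike the old lower-cased comparison.
      }
    }
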
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala
index 3b23c6cd2816f..134d16e981a15 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala
@@ -44,7 +44,9 @@ object JacksonUtils {
case at: ArrayType => verifyType(name, at.elementType)
- case mt: MapType => verifyType(name, mt.keyType)
+ // For MapType, its keys are treated as strings (i.e. rendered via `toString`) when
+ // generating JSON, so we only care whether the values are valid for JSON.
+ case mt: MapType => verifyType(name, mt.valueType)
case udt: UserDefinedType[_] => verifyType(name, udt.sqlType)
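
In other words, an integer-keyed map is acceptable to the JSON writer because the keys are stringified, so only the value type has to be JSON-compatible. A hedged usage sketch with to_json (the exact key rendering may vary by version; here the map is nested in a struct):

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.{struct, to_json}

    object MapToJsonExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("to_json-map").getOrCreate()
        import spark.implicits._
        // Integer keys are fine: they are rendered via toString; only the value type matters.
        val df = Seq((1, Map(1 -> "a", 2 -> "b"))).toDF("id", "m")
        df.select(to_json(struct($"id", $"m"))).show(false)  // roughly {"id":1,"m":{"1":"a","2":"b"}}
        spark.stop()
      }
    }
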
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index d221b0611a892..fe668217a6a5e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -113,17 +113,18 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf)
SimplifyCreateArrayOps,
SimplifyCreateMapOps) ++
extendedOperatorOptimizationRules: _*) ::
- Batch("Check Cartesian Products", Once,
- CheckCartesianProducts(conf)) ::
Batch("Join Reorder", Once,
CostBasedJoinReorder(conf)) ::
Batch("Decimal Optimizations", fixedPoint,
DecimalAggregates(conf)) ::
- Batch("Typed Filter Optimization", fixedPoint,
+ Batch("Object Expressions Optimization", fixedPoint,
+ EliminateMapObjects,
CombineTypedFilters) ::
Batch("LocalRelation", fixedPoint,
ConvertToLocalRelation,
PropagateEmptyRelation) ::
+ Batch("Check Cartesian Products", Once,
+ CheckCartesianProducts(conf)) ::
Batch("OptimizeCodegen", Once,
OptimizeCodegen(conf)) ::
Batch("RewriteSubquery", Once,
@@ -298,12 +299,11 @@ case class LimitPushDown(conf: SQLConf) extends Rule[LogicalPlan] {
// pushdown Limit.
case LocalLimit(exp, Union(children)) =>
LocalLimit(exp, Union(children.map(maybePushLimit(exp, _))))
- // Add extra limits below OUTER JOIN. For LEFT OUTER and FULL OUTER JOIN we push limits to the
- // left and right sides, respectively. For FULL OUTER JOIN, we can only push limits to one side
- // because we need to ensure that rows from the limited side still have an opportunity to match
- // against all candidates from the non-limited side. We also need to ensure that this limit
- // pushdown rule will not eventually introduce limits on both sides if it is applied multiple
- // times. Therefore:
+ // Add extra limits below OUTER JOIN. For LEFT OUTER and RIGHT OUTER JOIN we push limits to
+ // the left and right sides, respectively. It's not safe to push limits below FULL OUTER
+ // JOIN in the general case without a more invasive rewrite.
+ // We also need to ensure that this limit pushdown rule will not eventually introduce limits
+ // on both sides if it is applied multiple times. Therefore:
// - If one side is already limited, stack another limit on top if the new limit is smaller.
// The redundant limit will be collapsed by the CombineLimits rule.
// - If neither side is limited, limit the side that is estimated to be bigger.
@@ -311,19 +311,6 @@ case class LimitPushDown(conf: SQLConf) extends Rule[LogicalPlan] {
val newJoin = joinType match {
case RightOuter => join.copy(right = maybePushLimit(exp, right))
case LeftOuter => join.copy(left = maybePushLimit(exp, left))
- case FullOuter =>
- (left.maxRows, right.maxRows) match {
- case (None, None) =>
- if (left.stats(conf).sizeInBytes >= right.stats(conf).sizeInBytes) {
- join.copy(left = maybePushLimit(exp, left))
- } else {
- join.copy(right = maybePushLimit(exp, right))
- }
- case (Some(_), Some(_)) => join
- case (Some(_), None) => join.copy(left = maybePushLimit(exp, left))
- case (None, Some(_)) => join.copy(right = maybePushLimit(exp, right))
-
- }
case _ => join
}
LocalLimit(exp, newJoin)
@@ -440,8 +427,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
g.copy(child = prunedChild(g.child, g.references))
// Turn off `join` for Generate if no column from it's child is used
- case p @ Project(_, g: Generate)
- if g.join && !g.outer && p.references.subsetOf(g.generatedSet) =>
+ case p @ Project(_, g: Generate) if g.join && p.references.subsetOf(g.generatedSet) =>
p.copy(child = g.copy(join = false))
// Eliminate unneeded attributes from right side of a Left Existence Join.
@@ -768,7 +754,8 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild))
case filter @ Filter(condition, aggregate: Aggregate)
- if aggregate.aggregateExpressions.forall(_.deterministic) =>
+ if aggregate.aggregateExpressions.forall(_.deterministic)
+ && aggregate.groupingExpressions.nonEmpty =>
// Find all the aliased expressions in the aggregate list that don't include any actual
// AggregateExpression, and create a map from the alias to the expression
val aliasMap = AttributeMap(aggregate.aggregateExpressions.collect {
@@ -861,7 +848,7 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
// Note that some operators (e.g. project, aggregate, union) are being handled separately
// (earlier in this rule).
case _: AppendColumns => true
- case _: BroadcastHint => true
+ case _: ResolvedHint => true
case _: Distinct => true
case _: Generate => true
case _: Pivot => true
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
index 7400a01918c52..987cd7434b459 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
@@ -18,7 +18,6 @@
package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
@@ -30,7 +29,7 @@ import org.apache.spark.sql.catalyst.rules._
* - Join with one or two empty children (including Intersect/Except).
* 2. Unary-node Logical Plans
* - Project/Filter/Sample/Join/Limit/Repartition with all empty children.
- * - Aggregate with all empty children and without AggregateFunction expressions like COUNT.
+ * - Aggregate with all empty children and at least one grouping expression.
* - Generate(Explode) with all empty children. Others like Hive UDTF may return results.
*/
object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper {
@@ -39,10 +38,6 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper {
case _ => false
}
- private def containsAggregateExpression(e: Expression): Boolean = {
- e.collectFirst { case _: AggregateFunction => () }.isDefined
- }
-
private def empty(plan: LogicalPlan) = LocalRelation(plan.output, data = Seq.empty)
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
@@ -68,8 +63,13 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper {
case _: LocalLimit => empty(p)
case _: Repartition => empty(p)
case _: RepartitionByExpression => empty(p)
- // AggregateExpressions like COUNT(*) return their results like 0.
- case Aggregate(_, ae, _) if !ae.exists(containsAggregateExpression) => empty(p)
+ // An aggregate with non-empty group expression will return one output row per group when the
+ // input to the aggregate is not empty. If the input to the aggregate is empty then all groups
+ // will be empty and thus the output will be empty.
+ //
+ // If the grouping expressions are empty, however, then the aggregate will always produce a
+ // single output row and thus we cannot propagate the EmptyRelation.
+ case Aggregate(ge, _, _) if ge.nonEmpty => empty(p)
// Generators like Hive-style UDTF may return their records within `close`.
case Generate(_: Explode, _, _, _, _, _) => empty(p)
case _ => p
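
The distinction drawn in the new comment is easy to see in SQL: a global aggregate over empty input still produces exactly one row, while a grouped aggregate produces none, so only the latter may be replaced by an empty relation. A hedged sketch:

    import org.apache.spark.sql.SparkSession

    object EmptyAggregateExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("empty-agg").getOrCreate()
        spark.range(0).createOrReplaceTempView("t")   // an empty relation

        // Global aggregate: exactly one row (count = 0), so it must not be pruned away.
        spark.sql("SELECT COUNT(*) FROM t").show()
        // Grouped aggregate: zero rows, so the optimizer may substitute an empty relation.
        spark.sql("SELECT id, COUNT(*) FROM t GROUP BY id").show()
        spark.stop()
      }
    }
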
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 8445ee06bd89b..f2334830f8d88 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
+import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
@@ -153,6 +154,11 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
case TrueLiteral Or _ => TrueLiteral
case _ Or TrueLiteral => TrueLiteral
+ case a And b if Not(a).semanticEquals(b) => FalseLiteral
+ case a Or b if Not(a).semanticEquals(b) => TrueLiteral
+ case a And b if a.semanticEquals(Not(b)) => FalseLiteral
+ case a Or b if a.semanticEquals(Not(b)) => TrueLiteral
+
case a And b if a.semanticEquals(b) => a
case a Or b if a.semanticEquals(b) => a
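
The four new cases fold a predicate that is conjoined or disjoined with its own negation: p AND NOT p is always false and p OR NOT p is always true, as long as the two sides are semantically equal. A hedged sketch that would exercise the rule (what the final optimized plan looks like depends on the rest of the rule set):

    import org.apache.spark.sql.SparkSession

    object ComplementFoldExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("bool-simplify").getOrCreate()
        spark.range(10).createOrReplaceTempView("t")
        // p AND NOT p  =>  false: the contradictory predicate should not survive optimization.
        spark.sql("SELECT * FROM t WHERE id > 1 AND NOT (id > 1)").explain(true)
        // p OR NOT p   =>  true: the filter should disappear from the optimized plan.
        spark.sql("SELECT * FROM t WHERE id > 1 OR NOT (id > 1)").explain(true)
        spark.stop()
      }
    }
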
@@ -320,22 +326,27 @@ object LikeSimplification extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case Like(input, Literal(pattern, StringType)) =>
- pattern.toString match {
- case startsWith(prefix) if !prefix.endsWith("\\") =>
- StartsWith(input, Literal(prefix))
- case endsWith(postfix) =>
- EndsWith(input, Literal(postfix))
- // 'a%a' pattern is basically same with 'a%' && '%a'.
- // However, the additional `Length` condition is required to prevent 'a' match 'a%a'.
- case startsAndEndsWith(prefix, postfix) if !prefix.endsWith("\\") =>
- And(GreaterThanOrEqual(Length(input), Literal(prefix.size + postfix.size)),
- And(StartsWith(input, Literal(prefix)), EndsWith(input, Literal(postfix))))
- case contains(infix) if !infix.endsWith("\\") =>
- Contains(input, Literal(infix))
- case equalTo(str) =>
- EqualTo(input, Literal(str))
- case _ =>
- Like(input, Literal.create(pattern, StringType))
+ if (pattern == null) {
+ // If the pattern is null, return a null literal directly, since "col like null" is null.
+ Literal(null, BooleanType)
+ } else {
+ pattern.toString match {
+ case startsWith(prefix) if !prefix.endsWith("\\") =>
+ StartsWith(input, Literal(prefix))
+ case endsWith(postfix) =>
+ EndsWith(input, Literal(postfix))
+ // The 'a%a' pattern is essentially 'a%' && '%a'.
+ // However, the additional `Length` condition is required to prevent 'a' from matching 'a%a'.
+ case startsAndEndsWith(prefix, postfix) if !prefix.endsWith("\\") =>
+ And(GreaterThanOrEqual(Length(input), Literal(prefix.length + postfix.length)),
+ And(StartsWith(input, Literal(prefix)), EndsWith(input, Literal(postfix))))
+ case contains(infix) if !infix.endsWith("\\") =>
+ Contains(input, Literal(infix))
+ case equalTo(str) =>
+ EqualTo(input, Literal(str))
+ case _ =>
+ Like(input, Literal.create(pattern, StringType))
+ }
}
}
}
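
Previously a null pattern could reach pattern.toString and blow up inside the optimizer; now col LIKE NULL simplifies straight to a null boolean literal, matching SQL semantics. A hedged sketch:

    import org.apache.spark.sql.SparkSession

    object LikeNullExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("like-null").getOrCreate()
        // "value LIKE NULL" is NULL for every row, so the optimizer can emit a null literal
        // instead of keeping a Like expression with a null pattern.
        spark.sql("SELECT 'abc' LIKE CAST(NULL AS STRING)").show()
        spark.stop()
      }
    }
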
@@ -368,6 +379,8 @@ case class NullPropagation(conf: SQLConf) extends Rule[LogicalPlan] {
case EqualNullSafe(Literal(null, _), r) => IsNull(r)
case EqualNullSafe(l, Literal(null, _)) => IsNull(l)
+ case AssertNotNull(c, _) if !c.nullable => c
+
// For Coalesce, remove null literals.
case e @ Coalesce(children) =>
val newChildren = children.filterNot(isNullLiteral)
@@ -469,7 +482,7 @@ object FoldablePropagation extends Rule[LogicalPlan] {
case _: Distinct => true
case _: AppendColumns => true
case _: AppendColumnsWithObject => true
- case _: BroadcastHint => true
+ case _: ResolvedHint => true
case _: RepartitionByExpression => true
case _: Repartition => true
case _: Sort => true
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
index 89e1dc9e322e0..af0837e36e8ad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.catalyst.optimizer
-import java.util.TimeZone
-
import scala.collection.mutable
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
@@ -55,7 +53,7 @@ object ComputeCurrentTime extends Rule[LogicalPlan] {
case CurrentDate(Some(timeZoneId)) =>
currentDates.getOrElseUpdate(timeZoneId, {
Literal.create(
- DateTimeUtils.millisToDays(timestamp / 1000L, TimeZone.getTimeZone(timeZoneId)),
+ DateTimeUtils.millisToDays(timestamp / 1000L, DateTimeUtils.getTimeZone(timeZoneId)),
DateType)
})
case CurrentTimestamp() => currentTime
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
index c3ab58744953d..2fe3039774423 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
@@ -134,8 +134,8 @@ case class EliminateOuterJoin(conf: SQLConf) extends Rule[LogicalPlan] with Pred
val leftConditions = conditions.filter(_.references.subsetOf(join.left.outputSet))
val rightConditions = conditions.filter(_.references.subsetOf(join.right.outputSet))
- val leftHasNonNullPredicate = leftConditions.exists(canFilterOutNull)
- val rightHasNonNullPredicate = rightConditions.exists(canFilterOutNull)
+ lazy val leftHasNonNullPredicate = leftConditions.exists(canFilterOutNull)
+ lazy val rightHasNonNullPredicate = rightConditions.exists(canFilterOutNull)
join.joinType match {
case RightOuter if leftHasNonNullPredicate => Inner
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
index 257dbfac8c3e8..8cdc6425bcad8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.api.java.function.FilterFunction
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
@@ -96,3 +97,15 @@ object CombineTypedFilters extends Rule[LogicalPlan] {
}
}
}
+
+/**
+ * Removes MapObjects when the following conditions are satisfied:
+ *   1. MapObjects(... LambdaVariable(..., nullable = false) ...), i.e. both the input and the
+ *      output element types are non-nullable primitive types.
+ *   2. No custom collection class is specified as the representation of the data items.
+ */
+object EliminateMapObjects extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ case MapObjects(_, _, _, LambdaVariable(_, _, _, false), inputData, None) => inputData
+ }
+}
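
MapObjects is just a per-element conversion loop; when the loop body is the non-nullable primitive loop variable itself and no custom collection class is requested, the loop is an identity and the input data can be used as-is. A hedged example of a query whose deserializer contains such an identity MapObjects (a Dataset of primitive arrays):

    import org.apache.spark.sql.SparkSession

    object EliminateMapObjectsExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("map-objects").getOrCreate()
        import spark.implicits._
        // Converting Array[Int] element by element is an identity for primitive, non-nullable
        // elements, so the optimizer can drop the MapObjects loop from the deserializer.
        val ds = Seq(Array(1, 2, 3), Array(4, 5)).toDS()
        ds.map(_.sum).explain(true)
        spark.stop()
      }
    }
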
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index 2a3e07aebe709..28c5a9bea275e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -335,13 +335,14 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
case ne => (ne.exprId, evalAggOnZeroTups(ne))
}.toMap
- case _ => sys.error(s"Unexpected operator in scalar subquery: $lp")
+ case _ =>
+ sys.error(s"Unexpected operator in scalar subquery: $lp")
}
val resultMap = evalPlan(plan)
// By convention, the scalar subquery result is the leftmost field.
- resultMap(plan.output.head.exprId)
+ resultMap.getOrElse(plan.output.head.exprId, None)
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index e1db1ef5b8695..c15899cb230e3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last}
import org.apache.spark.sql.catalyst.parser.SqlBaseParser._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
import org.apache.spark.util.random.RandomSampler
@@ -44,9 +45,11 @@ import org.apache.spark.util.random.RandomSampler
* The AstBuilder converts an ANTLR4 ParseTree into a catalyst Expression, LogicalPlan or
* TableIdentifier.
*/
-class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
+class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging {
import ParserUtils._
+ def this() = this(new SQLConf())
+
protected def typedVisit[T](ctx: ParseTree): T = {
ctx.accept(this).asInstanceOf[T]
}
@@ -215,7 +218,10 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
*/
protected def visitNonOptionalPartitionSpec(
ctx: PartitionSpecContext): Map[String, String] = withOrigin(ctx) {
- visitPartitionSpec(ctx).mapValues(_.orNull).map(identity)
+ visitPartitionSpec(ctx).map {
+ case (key, None) => throw new ParseException(s"Found an empty partition key '$key'.", ctx)
+ case (key, Some(value)) => key -> value
+ }
}
/**
@@ -400,7 +406,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
val withWindow = withDistinct.optionalMap(windows)(withWindows)
// Hint
- withWindow.optionalMap(hint)(withHints)
+ hints.asScala.foldRight(withWindow)(withHints)
}
}
@@ -526,13 +532,16 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
}
/**
- * Add a [[Hint]] to a logical plan.
+ * Add [[UnresolvedHint]]s to a logical plan.
*/
private def withHints(
ctx: HintContext,
query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
- val stmt = ctx.hintStatement
- Hint(stmt.hintName.getText, stmt.parameters.asScala.map(_.getText), query)
+ var plan = query
+ ctx.hintStatements.asScala.reverse.foreach { case stmt =>
+ plan = UnresolvedHint(stmt.hintName.getText, stmt.parameters.asScala.map(expression), plan)
+ }
+ plan
}
/**
@@ -1024,6 +1033,13 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
Cast(expression(ctx.expression), visitSparkDataType(ctx.dataType))
}
+ /**
+ * Create a [[CreateStruct]] expression.
+ */
+ override def visitStruct(ctx: StructContext): Expression = withOrigin(ctx) {
+ CreateStruct(ctx.argument.asScala.map(expression))
+ }
+
/**
* Create a [[First]] expression.
*/
@@ -1047,7 +1063,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
// Create the function call.
val name = ctx.qualifiedName.getText
val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null)
- val arguments = ctx.namedExpression().asScala.map(expression) match {
+ val arguments = ctx.argument.asScala.map(expression) match {
case Seq(UnresolvedStar(None))
if name.toLowerCase(Locale.ROOT) == "count" && !isDistinct =>
// Transform COUNT(*) into COUNT(1).
@@ -1067,19 +1083,6 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
}
}
- /**
- * Create a current timestamp/date expression. These are different from regular function because
- * they do not require the user to specify braces when calling them.
- */
- override def visitTimeFunctionCall(ctx: TimeFunctionCallContext): Expression = withOrigin(ctx) {
- ctx.name.getType match {
- case SqlBaseParser.CURRENT_DATE =>
- CurrentDate()
- case SqlBaseParser.CURRENT_TIMESTAMP =>
- CurrentTimestamp()
- }
- }
-
/**
* Create a function database (optional) and name pair.
*/
@@ -1406,7 +1409,11 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
* Special characters can be escaped by using Hive/C-style escaping.
*/
private def createString(ctx: StringLiteralContext): String = {
- ctx.STRING().asScala.map(string).mkString
+ if (conf.escapedStringLiterals) {
+ ctx.STRING().asScala.map(stringWithoutUnescape).mkString
+ } else {
+ ctx.STRING().asScala.map(string).mkString
+ }
}
/**
@@ -1488,8 +1495,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
case ("decimal", precision :: scale :: Nil) =>
DecimalType(precision.getText.toInt, scale.getText.toInt)
case (dt, params) =>
- throw new ParseException(
- s"DataType $dt${params.mkString("(", ",", ")")} is not supported.", ctx)
+ val dtStr = if (params.nonEmpty) s"$dt(${params.mkString(",")})" else dt
+ throw new ParseException(s"DataType $dtStr is not supported.", ctx)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
index 80ab75cc17fab..8e2e973485e1c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees.Origin
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, StructType}
/**
@@ -34,8 +35,7 @@ import org.apache.spark.sql.types.{DataType, StructType}
abstract class AbstractSqlParser extends ParserInterface with Logging {
/** Creates/Resolves DataType for a given SQL string. */
- def parseDataType(sqlText: String): DataType = parse(sqlText) { parser =>
- // TODO add this to the parser interface.
+ override def parseDataType(sqlText: String): DataType = parse(sqlText) { parser =>
astBuilder.visitSingleDataType(parser.singleDataType())
}
@@ -50,8 +50,10 @@ abstract class AbstractSqlParser extends ParserInterface with Logging {
}
/** Creates FunctionIdentifier for a given SQL string. */
- def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = parse(sqlText) { parser =>
- astBuilder.visitSingleFunctionIdentifier(parser.singleFunctionIdentifier())
+ override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = {
+ parse(sqlText) { parser =>
+ astBuilder.visitSingleFunctionIdentifier(parser.singleFunctionIdentifier())
+ }
}
/**
@@ -120,8 +122,13 @@ abstract class AbstractSqlParser extends ParserInterface with Logging {
/**
* Concrete SQL parser for Catalyst-only SQL statements.
*/
+class CatalystSqlParser(conf: SQLConf) extends AbstractSqlParser {
+ val astBuilder = new AstBuilder(conf)
+}
+
+/** For test-only. */
object CatalystSqlParser extends AbstractSqlParser {
- val astBuilder = new AstBuilder
+ val astBuilder = new AstBuilder(new SQLConf())
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala
index db3598bde04d3..75240d2196222 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala
@@ -17,30 +17,51 @@
package org.apache.spark.sql.catalyst.parser
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, StructType}
/**
* Interface for a parser.
*/
+@DeveloperApi
trait ParserInterface {
- /** Creates LogicalPlan for a given SQL string. */
+ /**
+ * Parse a string to a [[LogicalPlan]].
+ */
+ @throws[ParseException]("Text cannot be parsed to a LogicalPlan")
def parsePlan(sqlText: String): LogicalPlan
- /** Creates Expression for a given SQL string. */
+ /**
+ * Parse a string to an [[Expression]].
+ */
+ @throws[ParseException]("Text cannot be parsed to an Expression")
def parseExpression(sqlText: String): Expression
- /** Creates TableIdentifier for a given SQL string. */
+ /**
+ * Parse a string to a [[TableIdentifier]].
+ */
+ @throws[ParseException]("Text cannot be parsed to a TableIdentifier")
def parseTableIdentifier(sqlText: String): TableIdentifier
- /** Creates FunctionIdentifier for a given SQL string. */
+ /**
+ * Parse a string to a [[FunctionIdentifier]].
+ */
+ @throws[ParseException]("Text cannot be parsed to a FunctionIdentifier")
def parseFunctionIdentifier(sqlText: String): FunctionIdentifier
/**
- * Creates StructType for a given SQL string, which is a comma separated list of field
- * definitions which will preserve the correct Hive metadata.
+ * Parse a string to a [[StructType]]. The passed SQL string should be a comma separated list
+ * of field definitions which will preserve the correct Hive metadata.
*/
+ @throws[ParseException]("Text cannot be parsed to a schema")
def parseTableSchema(sqlText: String): StructType
+
+ /**
+ * Parse a string to a [[DataType]].
+ */
+ @throws[ParseException]("Text cannot be parsed to a DataType")
+ def parseDataType(sqlText: String): DataType
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
index 6fbc33fad735c..77fdaa8255aa6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
@@ -68,6 +68,12 @@ object ParserUtils {
/** Convert a string node into a string. */
def string(node: TerminalNode): String = unescapeSQLString(node.getText)
+ /** Convert a string node into a string without unescaping. */
+ def stringWithoutUnescape(node: TerminalNode): String = {
+ // The STRING parser rule guarantees that the input always starts and ends with a quote.
+ node.getText.slice(1, node.getText.size - 1)
+ }
+
/** Get the origin (line and position) of the token. */
def position(token: Token): Origin = {
val opt = Option(token)
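
stringWithoutUnescape simply strips the surrounding quotes and leaves backslashes untouched; AstBuilder uses it when escaped string literals are enabled via SQLConf, so a SQL pattern such as '\d+' keeps its backslash instead of being unescaped. A standalone restatement of the slicing (the config key named in the comment is an assumption about how the flag is exposed):

    object RawStringLiteralSketch {
      // Mirrors stringWithoutUnescape: drop the surrounding quotes, keep escapes as-is.
      def stringWithoutUnescape(quoted: String): String = quoted.slice(1, quoted.length - 1)

      def main(args: Array[String]): Unit = {
        // With the flag on (e.g. spark.sql.parser.escapedStringLiterals=true; key is an
        // assumption), the regex literal below keeps its backslash:
        println(stringWithoutUnescape("'\\d+'"))   // prints \d+
      }
    }
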
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index d39b0ef7e1d8a..ef925f92ecc7e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -65,8 +65,8 @@ object PhysicalOperation extends PredicateHelper {
val substitutedCondition = substitute(aliases)(condition)
(fields, filters ++ splitConjunctivePredicates(substitutedCondition), other, aliases)
- case BroadcastHint(child) =>
- collectProjectsAndFilters(child)
+ case h: ResolvedHint =>
+ collectProjectsAndFilters(h.child)
case other =>
(None, Nil, other, Map.empty)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 2fb65bd435507..d3f822bf7eb0e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -81,11 +81,12 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
case _ => Seq.empty[Attribute]
}
- // Collect aliases from expressions, so we may avoid producing recursive constraints.
- private lazy val aliasMap = AttributeMap(
- (expressions ++ children.flatMap(_.expressions)).collect {
+ // Collect aliases from expressions of the whole tree rooted by the current QueryPlan node, so
+ // we may avoid producing recursive constraints.
+ private lazy val aliasMap: AttributeMap[Expression] = AttributeMap(
+ expressions.collect {
case a: Alias => (a.toAttribute, a.child)
- })
+ } ++ children.flatMap(_.aliasMap))
/**
* Infers an additional set of constraints from a given set of equality constraints.
@@ -286,7 +287,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
def recursiveTransform(arg: Any): AnyRef = arg match {
case e: Expression => transformExpression(e)
- case Some(e: Expression) => Some(transformExpression(e))
+ case Some(value) => Some(recursiveTransform(value))
case m: Map[_, _] => m
case d: DataType => d // Avoid unpacking Structs
case seq: Traversable[_] => seq.map(recursiveTransform)
@@ -320,7 +321,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
productIterator.flatMap {
case e: Expression => e :: Nil
- case Some(e: Expression) => e :: Nil
+ case s: Some[_] => seqToExpressions(s.toSeq)
case seq: Traversable[_] => seqToExpressions(seq)
case other => Nil
}.toSeq
@@ -423,7 +424,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
lazy val allAttributes: AttributeSeq = children.flatMap(_.output)
}
-object QueryPlan {
+object QueryPlan extends PredicateHelper {
/**
* Normalize the exprIds in the given expression, by updating the exprId in `AttributeReference`
* with its referenced ordinal from input attributes. It's similar to `BindReferences` but we
@@ -442,4 +443,17 @@ object QueryPlan {
}
}.canonicalized.asInstanceOf[T]
}
+
+ /**
+ * Composes the given predicates into a conjunctive predicate, which is normalized and reordered.
+ * Then returns a new sequence of predicates by splitting the conjunctive predicate.
+ */
+ def normalizePredicates(predicates: Seq[Expression], output: AttributeSeq): Seq[Expression] = {
+ if (predicates.nonEmpty) {
+ val normalized = normalizeExprId(predicates.reduce(And), output)
+ splitConjunctivePredicates(normalized)
+ } else {
+ Nil
+ }
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 6bdcf490ca5c8..2ebb2ff323c6b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -347,7 +347,7 @@ abstract class UnaryNode extends LogicalPlan {
}
// Don't propagate rowCount and attributeStats, since they are not estimated here.
- Statistics(sizeInBytes = sizeInBytes, isBroadcastable = child.stats(conf).isBroadcastable)
+ Statistics(sizeInBytes = sizeInBytes, hints = child.stats(conf).hints)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
index 3d4efef953a64..a64562b5dbd93 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
@@ -46,13 +46,13 @@ import org.apache.spark.util.Utils
* defaults to the product of children's `sizeInBytes`.
* @param rowCount Estimated number of rows.
* @param attributeStats Statistics for Attributes.
- * @param isBroadcastable If true, output is small enough to be used in a broadcast join.
+ * @param hints Query hints.
*/
case class Statistics(
sizeInBytes: BigInt,
rowCount: Option[BigInt] = None,
attributeStats: AttributeMap[ColumnStat] = AttributeMap(Nil),
- isBroadcastable: Boolean = false) {
+ hints: HintInfo = HintInfo()) {
override def toString: String = "Statistics(" + simpleString + ")"
@@ -65,7 +65,7 @@ case class Statistics(
} else {
""
},
- s"isBroadcastable=$isBroadcastable"
+ s"hints=$hints"
).filter(_.nonEmpty).mkString(", ")
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 3ad757ebba851..2eee94364d84e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -83,7 +83,7 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
* @param join when true, each output row is implicitly joined with the input tuple that produced
* it.
* @param outer when true, each input row will be output at least once, even if the output of the
- * given `generator` is empty. `outer` has no effect when `join` is false.
+ * given `generator` is empty.
* @param qualifier Qualifier for the attributes of generator(UDTF)
* @param generatorOutput The output schema of the Generator.
* @param child Children logical plan node
@@ -195,9 +195,9 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation
val leftSize = left.stats(conf).sizeInBytes
val rightSize = right.stats(conf).sizeInBytes
val sizeInBytes = if (leftSize < rightSize) leftSize else rightSize
- val isBroadcastable = left.stats(conf).isBroadcastable || right.stats(conf).isBroadcastable
-
- Statistics(sizeInBytes = sizeInBytes, isBroadcastable = isBroadcastable)
+ Statistics(
+ sizeInBytes = sizeInBytes,
+ hints = left.stats(conf).hints.resetForJoin())
}
}
@@ -364,7 +364,8 @@ case class Join(
case _ =>
// Make sure we don't propagate isBroadcastable in other joins, because
// they could explode the size.
- super.computeStats(conf).copy(isBroadcastable = false)
+ val stats = super.computeStats(conf)
+ stats.copy(hints = stats.hints.resetForJoin())
}
if (conf.cboEnabled) {
@@ -375,26 +376,6 @@ case class Join(
}
}
-/**
- * A hint for the optimizer that we should broadcast the `child` if used in a join operator.
- */
-case class BroadcastHint(child: LogicalPlan) extends UnaryNode {
- override def output: Seq[Attribute] = child.output
-
- // set isBroadcastable to true so the child will be broadcasted
- override def computeStats(conf: SQLConf): Statistics =
- child.stats(conf).copy(isBroadcastable = true)
-}
-
-/**
- * A general hint for the child. This node will be eliminated post analysis.
- * A pair of (name, parameters).
- */
-case class Hint(name: String, parameters: Seq[String], child: LogicalPlan) extends UnaryNode {
- override lazy val resolved: Boolean = false
- override def output: Seq[Attribute] = child.output
-}
-
/**
* Insert some data into a table. Note that this plan is unresolved and has to be replaced by the
* concrete implementations during analysis.
@@ -410,17 +391,20 @@ case class Hint(name: String, parameters: Seq[String], child: LogicalPlan) exten
* would have Map('a' -> Some('1'), 'b' -> None).
* @param query the logical plan representing data to write to.
* @param overwrite overwrite existing table or partitions.
- * @param ifNotExists If true, only write if the table or partition does not exist.
+ * @param ifPartitionNotExists If true, only write if the partition does not exist.
+ * Only valid for static partitions.
*/
case class InsertIntoTable(
table: LogicalPlan,
partition: Map[String, Option[String]],
query: LogicalPlan,
overwrite: Boolean,
- ifNotExists: Boolean)
+ ifPartitionNotExists: Boolean)
extends LogicalPlan {
- assert(overwrite || !ifNotExists)
- assert(partition.values.forall(_.nonEmpty) || !ifNotExists)
+ // IF NOT EXISTS is only valid in INSERT OVERWRITE
+ assert(overwrite || !ifPartitionNotExists)
+ // IF NOT EXISTS is only valid in static partitions
+ assert(partition.values.forall(_.nonEmpty) || !ifPartitionNotExists)
// We don't want `table` in children as sometimes we don't want to transform it.
override def children: Seq[LogicalPlan] = query :: Nil
@@ -577,7 +561,7 @@ case class Aggregate(
Statistics(
sizeInBytes = EstimationUtils.getOutputSize(output, outputRowCount = 1),
rowCount = Some(1),
- isBroadcastable = child.stats(conf).isBroadcastable)
+ hints = child.stats(conf).hints)
} else {
super.computeStats(conf)
}
@@ -704,7 +688,7 @@ case class Expand(
* We will transform GROUPING SETS into logical plan Aggregate(.., Expand) in Analyzer
*
* @param selectedGroupByExprs A sequence of selected GroupBy expressions, all exprs should
- * exists in groupByExprs.
+ * exist in groupByExprs.
* @param groupByExprs The Group By expressions candidates.
* @param child Child operator
* @param aggregations The Aggregation expressions, those non selected group by expressions
@@ -766,7 +750,7 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN
Statistics(
sizeInBytes = EstimationUtils.getOutputSize(output, rowCount, childStats.attributeStats),
rowCount = Some(rowCount),
- isBroadcastable = childStats.isBroadcastable)
+ hints = childStats.hints)
}
}
@@ -787,7 +771,7 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo
Statistics(
sizeInBytes = 1,
rowCount = Some(0),
- isBroadcastable = childStats.isBroadcastable)
+ hints = childStats.hints)
} else {
// The output row count of LocalLimit should be the sum of row counts from each partition.
// However, since the number of partitions is not available here, we just use statistics of
@@ -838,7 +822,7 @@ case class Sample(
}
val sampledRowCount = childStats.rowCount.map(c => EstimationUtils.ceil(BigDecimal(c) * ratio))
// Don't propagate column stats, because we don't know the distribution after a sample operation
- Statistics(sizeInBytes, sampledRowCount, isBroadcastable = childStats.isBroadcastable)
+ Statistics(sizeInBytes, sampledRowCount, hints = childStats.hints)
}
override protected def otherCopyArgs: Seq[AnyRef] = isTableSample :: Nil
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
new file mode 100644
index 0000000000000..d16fae56b3d4a
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical
+
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * A general hint for the child that is not yet resolved. This node is generated by the parser
+ * and will be eliminated post analysis.
+ * @param name the name of the hint
+ * @param parameters the parameters of the hint
+ * @param child the [[LogicalPlan]] on which this hint applies
+ */
+case class UnresolvedHint(name: String, parameters: Seq[Any], child: LogicalPlan)
+ extends UnaryNode {
+
+ override lazy val resolved: Boolean = false
+ override def output: Seq[Attribute] = child.output
+}
+
+/**
+ * A resolved hint node. The analyzer should convert all [[UnresolvedHint]] into [[ResolvedHint]].
+ */
+case class ResolvedHint(child: LogicalPlan, hints: HintInfo = HintInfo())
+ extends UnaryNode {
+
+ override def output: Seq[Attribute] = child.output
+
+ override lazy val canonicalized: LogicalPlan = child.canonicalized
+
+ override def computeStats(conf: SQLConf): Statistics = {
+ val stats = child.stats(conf)
+ stats.copy(hints = hints)
+ }
+}
+
+
+case class HintInfo(
+ isBroadcastable: Option[Boolean] = None) {
+
+ /** Must be called when computing stats for a join operator to reset hints. */
+ def resetForJoin(): HintInfo = copy(
+ isBroadcastable = None
+ )
+
+ override def toString: String = {
+ if (productIterator.forall(_.asInstanceOf[Option[_]].isEmpty)) {
+ "none"
+ } else {
+ isBroadcastable.map(x => s"isBroadcastable=$x").getOrElse("")
+ }
+ }
+}
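
BroadcastHint is replaced by the generic UnresolvedHint/ResolvedHint pair carrying a HintInfo: a broadcast request now travels as isBroadcastable = Some(true), and statistics computed for a join reset it via resetForJoin() so the hint does not propagate further up the plan. The user-facing entry point is unchanged; a hedged usage sketch:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.broadcast

    object BroadcastHintExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("hints").getOrCreate()
        val small = spark.range(100).withColumnRenamed("id", "k")
        val large = spark.range(1000000).withColumnRenamed("id", "k")
        // broadcast() wraps the small side in a hint node; after analysis this becomes a
        // ResolvedHint whose HintInfo has isBroadcastable = Some(true).
        large.join(broadcast(small), "k").explain()
        spark.stop()
      }
    }
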
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
index 48b5fbb03ef1e..a0c23198451a8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
@@ -56,7 +56,7 @@ object AggregateEstimation {
sizeInBytes = getOutputSize(agg.output, outputRows, outputAttrStats),
rowCount = Some(outputRows),
attributeStats = outputAttrStats,
- isBroadcastable = childStats.isBroadcastable))
+ hints = childStats.hints))
} else {
None
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
index f1aff62cb6af0..e5fcdf9039be9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
@@ -43,6 +43,18 @@ object EstimationUtils {
avgLen = dataType.defaultSize, maxLen = dataType.defaultSize)
}
+ /**
+ * Updates (scales down) the number of distinct values if the number of rows decreases after
+ * some operation (such as filter or join); otherwise it is kept unchanged.
+ */
+ def updateNdv(oldNumRows: BigInt, newNumRows: BigInt, oldNdv: BigInt): BigInt = {
+ if (newNumRows < oldNumRows) {
+ ceil(BigDecimal(oldNdv) * BigDecimal(newNumRows) / BigDecimal(oldNumRows))
+ } else {
+ oldNdv
+ }
+ }
+
def ceil(bigDecimal: BigDecimal): BigInt = bigDecimal.setScale(0, RoundingMode.CEILING).toBigInt()
/** Get column stats for output attributes. */
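
updateNdv scales the distinct-value count down in proportion to the row-count reduction and never scales it up. For example, if a filter cuts 1000 rows to 100 and a column had 50 distinct values, the new estimate is ceil(50 * 100 / 1000) = 5. A standalone restatement of that arithmetic:

    import scala.math.BigDecimal.RoundingMode

    object UpdateNdvSketch {
      // Mirrors EstimationUtils.updateNdv: shrink the NDV with the row count, never grow it.
      def updateNdv(oldNumRows: BigInt, newNumRows: BigInt, oldNdv: BigInt): BigInt = {
        if (newNumRows < oldNumRows) {
          (BigDecimal(oldNdv) * BigDecimal(newNumRows) / BigDecimal(oldNumRows))
            .setScale(0, RoundingMode.CEILING).toBigInt
        } else {
          oldNdv
        }
      }

      def main(args: Array[String]): Unit = {
        println(updateNdv(oldNumRows = 1000, newNumRows = 100, oldNdv = 50))  // 5
        println(updateNdv(oldNumRows = 1000, newNumRows = 2000, oldNdv = 50)) // 50 (unchanged)
      }
    }
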
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
index 4b6b3b14d9ac8..df190867189ec 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
@@ -19,12 +19,12 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
import scala.collection.immutable.HashSet
import scala.collection.mutable
-import scala.math.BigDecimal.RoundingMode
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, LeafNode, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -32,14 +32,7 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
private val childStats = plan.child.stats(catalystConf)
- /**
- * We will update the corresponding ColumnStats for a column after we apply a predicate condition.
- * For example, column c has [min, max] value as [0, 100]. In a range condition such as
- * (c > 40 AND c <= 50), we need to set the column's [min, max] value to [40, 100] after we
- * evaluate the first condition c > 40. We need to set the column's [min, max] value to [40, 50]
- * after we evaluate the second condition c <= 50.
- */
- private val colStatsMap = new ColumnStatsMap
+ private val colStatsMap = new ColumnStatsMap(childStats.attributeStats)
/**
* Returns an option of Statistics for a Filter logical plan node.
@@ -53,24 +46,19 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
def estimate: Option[Statistics] = {
if (childStats.rowCount.isEmpty) return None
- // Save a mutable copy of colStats so that we can later change it recursively.
- colStatsMap.setInitValues(childStats.attributeStats)
-
// Estimate selectivity of this filter predicate, and update column stats if needed.
// For not-supported condition, set filter selectivity to a conservative estimate 100%
- val filterSelectivity: Double = calculateFilterSelectivity(plan.condition).getOrElse(1.0)
+ val filterSelectivity = calculateFilterSelectivity(plan.condition).getOrElse(BigDecimal(1.0))
- val newColStats = if (filterSelectivity == 0) {
+ val filteredRowCount: BigInt = ceil(BigDecimal(childStats.rowCount.get) * filterSelectivity)
+ val newColStats = if (filteredRowCount == 0) {
// The output is empty, we don't need to keep column stats.
AttributeMap[ColumnStat](Nil)
} else {
- colStatsMap.toColumnStats
+ colStatsMap.outputColumnStats(rowsBeforeFilter = childStats.rowCount.get,
+ rowsAfterFilter = filteredRowCount)
}
-
- val filteredRowCount: BigInt =
- EstimationUtils.ceil(BigDecimal(childStats.rowCount.get) * filterSelectivity)
- val filteredSizeInBytes: BigInt =
- EstimationUtils.getOutputSize(plan.output, filteredRowCount, newColStats)
+ val filteredSizeInBytes: BigInt = getOutputSize(plan.output, filteredRowCount, newColStats)
Some(childStats.copy(sizeInBytes = filteredSizeInBytes, rowCount = Some(filteredRowCount),
attributeStats = newColStats))
@@ -92,16 +80,17 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
* @return an optional double value to show the percentage of rows meeting a given condition.
* It returns None if the condition is not supported.
*/
- def calculateFilterSelectivity(condition: Expression, update: Boolean = true): Option[Double] = {
+ def calculateFilterSelectivity(condition: Expression, update: Boolean = true)
+ : Option[BigDecimal] = {
condition match {
case And(cond1, cond2) =>
- val percent1 = calculateFilterSelectivity(cond1, update).getOrElse(1.0)
- val percent2 = calculateFilterSelectivity(cond2, update).getOrElse(1.0)
+ val percent1 = calculateFilterSelectivity(cond1, update).getOrElse(BigDecimal(1.0))
+ val percent2 = calculateFilterSelectivity(cond2, update).getOrElse(BigDecimal(1.0))
Some(percent1 * percent2)
case Or(cond1, cond2) =>
- val percent1 = calculateFilterSelectivity(cond1, update = false).getOrElse(1.0)
- val percent2 = calculateFilterSelectivity(cond2, update = false).getOrElse(1.0)
+ val percent1 = calculateFilterSelectivity(cond1, update = false).getOrElse(BigDecimal(1.0))
+ val percent2 = calculateFilterSelectivity(cond2, update = false).getOrElse(BigDecimal(1.0))
Some(percent1 + percent2 - (percent1 * percent2))
// Not-operator pushdown
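
Selectivities are now carried as BigDecimal and combined under the usual independence assumption: a conjunction multiplies the two estimates, and a disjunction uses inclusion-exclusion, s1 + s2 - s1 * s2 (with column stats left untouched on the OR branches). For instance, predicates estimated at 0.2 and 0.5 combine to 0.1 under AND and 0.6 under OR:

    object SelectivitySketch {
      // Independence-based combination, as used above for And/Or.
      def andSel(s1: BigDecimal, s2: BigDecimal): BigDecimal = s1 * s2
      def orSel(s1: BigDecimal, s2: BigDecimal): BigDecimal = s1 + s2 - s1 * s2

      def main(args: Array[String]): Unit = {
        val (s1, s2) = (BigDecimal(0.2), BigDecimal(0.5))
        println(andSel(s1, s2)) // 0.10
        println(orSel(s1, s2))  // 0.60
      }
    }
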
@@ -143,7 +132,7 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
* @return an optional double value to show the percentage of rows meeting a given condition.
* It returns None if the condition is not supported.
*/
- def calculateSingleCondition(condition: Expression, update: Boolean): Option[Double] = {
+ def calculateSingleCondition(condition: Expression, update: Boolean): Option[BigDecimal] = {
condition match {
case l: Literal =>
evaluateLiteral(l)
@@ -237,7 +226,7 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
def evaluateNullCheck(
attr: Attribute,
isNull: Boolean,
- update: Boolean): Option[Double] = {
+ update: Boolean): Option[BigDecimal] = {
if (!colStatsMap.contains(attr)) {
logDebug("[CBO] No statistics for " + attr)
return None
@@ -256,7 +245,7 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
} else {
colStat.copy(nullCount = 0)
}
- colStatsMap(attr) = newStats
+ colStatsMap.update(attr, newStats)
}
val percent = if (isNull) {
@@ -265,7 +254,7 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
1.0 - nullPercent
}
- Some(percent.toDouble)
+ Some(percent)
}
/**
@@ -283,7 +272,7 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
op: BinaryComparison,
attr: Attribute,
literal: Literal,
- update: Boolean): Option[Double] = {
+ update: Boolean): Option[BigDecimal] = {
if (!colStatsMap.contains(attr)) {
logDebug("[CBO] No statistics for " + attr)
return None
@@ -317,7 +306,7 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
def evaluateEquality(
attr: Attribute,
literal: Literal,
- update: Boolean): Option[Double] = {
+ update: Boolean): Option[BigDecimal] = {
if (!colStatsMap.contains(attr)) {
logDebug("[CBO] No statistics for " + attr)
return None
@@ -341,10 +330,10 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
colStat.copy(distinctCount = 1, min = Some(literal.value),
max = Some(literal.value), nullCount = 0)
}
- colStatsMap(attr) = newStats
+ colStatsMap.update(attr, newStats)
}
- Some((1.0 / BigDecimal(ndv)).toDouble)
+ Some(1.0 / BigDecimal(ndv))
} else {
Some(0.0)
}
@@ -361,7 +350,7 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
* @param literal a literal value (or constant)
* @return an optional double value to show the percentage of rows meeting a given condition
*/
- def evaluateLiteral(literal: Literal): Option[Double] = {
+ def evaluateLiteral(literal: Literal): Option[BigDecimal] = {
literal match {
case Literal(null, _) => Some(0.0)
case FalseLiteral => Some(0.0)
@@ -386,7 +375,7 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
def evaluateInSet(
attr: Attribute,
hSet: Set[Any],
- update: Boolean): Option[Double] = {
+ update: Boolean): Option[BigDecimal] = {
if (!colStatsMap.contains(attr)) {
logDebug("[CBO] No statistics for " + attr)
return None
@@ -417,7 +406,7 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
if (update) {
val newStats = colStat.copy(distinctCount = newNdv, min = Some(newMin),
max = Some(newMax), nullCount = 0)
- colStatsMap(attr) = newStats
+ colStatsMap.update(attr, newStats)
}
// We assume the whole set since there is no min/max information for String/Binary type
@@ -425,13 +414,13 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
newNdv = ndv.min(BigInt(hSet.size))
if (update) {
val newStats = colStat.copy(distinctCount = newNdv, nullCount = 0)
- colStatsMap(attr) = newStats
+ colStatsMap.update(attr, newStats)
}
}
// return the filter selectivity. Without advanced statistics such as histograms,
// we have to assume uniform distribution.
- Some(math.min(1.0, (BigDecimal(newNdv) / BigDecimal(ndv)).toDouble))
+ Some((BigDecimal(newNdv) / BigDecimal(ndv)).min(1.0))
}
/**
@@ -449,7 +438,7 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
op: BinaryComparison,
attr: Attribute,
literal: Literal,
- update: Boolean): Option[Double] = {
+ update: Boolean): Option[BigDecimal] = {
val colStat = colStatsMap(attr)
val statsRange = Range(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericRange]
@@ -518,7 +507,7 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
val newValue = Some(literal.value)
var newMax = colStat.max
var newMin = colStat.min
- var newNdv = (ndv * percent).setScale(0, RoundingMode.HALF_UP).toBigInt()
+ var newNdv = ceil(ndv * percent)
if (newNdv < 1) newNdv = 1
op match {
@@ -532,11 +521,11 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
val newStats =
colStat.copy(distinctCount = newNdv, min = newMin, max = newMax, nullCount = 0)
- colStatsMap(attr) = newStats
+ colStatsMap.update(attr, newStats)
}
}
- Some(percent.toDouble)
+ Some(percent)
}
/**
@@ -557,7 +546,7 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
op: BinaryComparison,
attrLeft: Attribute,
attrRight: Attribute,
- update: Boolean): Option[Double] = {
+ update: Boolean): Option[BigDecimal] = {
if (!colStatsMap.contains(attrLeft)) {
logDebug("[CBO] No statistics for " + attrLeft)
@@ -654,10 +643,10 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
// Need to adjust new min/max after the filter condition is applied
val ndvLeft = BigDecimal(colStatLeft.distinctCount)
- var newNdvLeft = (ndvLeft * percent).setScale(0, RoundingMode.HALF_UP).toBigInt()
+ var newNdvLeft = ceil(ndvLeft * percent)
if (newNdvLeft < 1) newNdvLeft = 1
val ndvRight = BigDecimal(colStatRight.distinctCount)
- var newNdvRight = (ndvRight * percent).setScale(0, RoundingMode.HALF_UP).toBigInt()
+ var newNdvRight = ceil(ndvRight * percent)
if (newNdvRight < 1) newNdvRight = 1
var newMaxLeft = colStatLeft.max
@@ -750,24 +739,57 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
}
}
- Some(percent.toDouble)
+ Some(percent)
}
}
-class ColumnStatsMap {
- private val baseMap: mutable.Map[ExprId, (Attribute, ColumnStat)] = mutable.HashMap.empty
+/**
+ * This class contains the original column stats from the child, and maintains the updated column stats.
+ * We will update the corresponding ColumnStats for a column after we apply a predicate condition.
+ * For example, column c has [min, max] value as [0, 100]. In a range condition such as
+ * (c > 40 AND c <= 50), we need to set the column's [min, max] value to [40, 100] after we
+ * evaluate the first condition c > 40. We also need to set the column's [min, max] value to
+ * [40, 50] after we evaluate the second condition c <= 50.
+ *
+ * @param originalMap Original column stats from child.
+ */
+case class ColumnStatsMap(originalMap: AttributeMap[ColumnStat]) {
- def setInitValues(colStats: AttributeMap[ColumnStat]): Unit = {
- baseMap.clear()
- baseMap ++= colStats.baseMap
- }
+ /** This map maintains the latest column stats. */
+ private val updatedMap: mutable.Map[ExprId, (Attribute, ColumnStat)] = mutable.HashMap.empty
- def contains(a: Attribute): Boolean = baseMap.contains(a.exprId)
+ def contains(a: Attribute): Boolean = updatedMap.contains(a.exprId) || originalMap.contains(a)
- def apply(a: Attribute): ColumnStat = baseMap(a.exprId)._2
+ /**
+ * Gets the column stat for the given attribute. Prefer the column stat in updatedMap over that
+ * in originalMap, because updatedMap has the latest (updated) column stats.
+ */
+ def apply(a: Attribute): ColumnStat = {
+ if (updatedMap.contains(a.exprId)) {
+ updatedMap(a.exprId)._2
+ } else {
+ originalMap(a)
+ }
+ }
- def update(a: Attribute, stats: ColumnStat): Unit = baseMap.update(a.exprId, a -> stats)
+ /** Updates column stats in updatedMap. */
+ def update(a: Attribute, stats: ColumnStat): Unit = updatedMap.update(a.exprId, a -> stats)
- def toColumnStats: AttributeMap[ColumnStat] = AttributeMap(baseMap.values.toSeq)
+ /**
+ * Collects updated column stats, and scales down ndv for other column stats if the number of rows
+ * decreases after this Filter operator.
+ */
+ def outputColumnStats(rowsBeforeFilter: BigInt, rowsAfterFilter: BigInt)
+ : AttributeMap[ColumnStat] = {
+ val newColumnStats = originalMap.map { case (attr, oriColStat) =>
+ // Update ndv based on the overall filter selectivity: scale down ndv if the number of rows
+ // decreases; otherwise keep it unchanged.
+ val newNdv = EstimationUtils.updateNdv(oldNumRows = rowsBeforeFilter,
+ newNumRows = rowsAfterFilter, oldNdv = oriColStat.distinctCount)
+ val colStat = updatedMap.get(attr.exprId).map(_._2).getOrElse(oriColStat)
+ attr -> colStat.copy(distinctCount = newNdv)
+ }
+ AttributeMap(newColumnStats.toSeq)
+ }
}
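
Note: the new outputColumnStats delegates the distinct-count adjustment to EstimationUtils.updateNdv (called with oldNumRows/newNumRows/oldNdv above). A minimal sketch of the scaling rule those call sites rely on, under the assumption that ndv is reduced proportionally when rows decrease and never increased; the helper name below is hypothetical, not the actual EstimationUtils code:

    import scala.math.BigDecimal.RoundingMode

    // Hypothetical sketch: scale ndv down in proportion to the row-count reduction,
    // and leave it unchanged otherwise (a filter never creates new values).
    def scaleNdv(oldNumRows: BigInt, newNumRows: BigInt, oldNdv: BigInt): BigInt = {
      if (oldNumRows == 0 || newNumRows >= oldNumRows) {
        oldNdv
      } else {
        val scaled = BigDecimal(oldNdv) * BigDecimal(newNumRows) / BigDecimal(oldNumRows)
        val ceiled = scaled.setScale(0, RoundingMode.CEILING).toBigInt()
        // Keep at least one distinct value for a non-empty output.
        if (newNumRows > 0) ceiled.max(BigInt(1)) else ceiled
      }
    }

For example, filtering 1000 rows down to 100 would shrink an ndv of 40 to 4.
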
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
index 3245a73c8a2eb..8ef905c45d50d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
@@ -217,32 +217,17 @@ case class InnerOuterEstimation(conf: SQLConf, join: Join) extends Logging {
if (joinKeyStats.contains(a)) {
outputAttrStats += a -> joinKeyStats(a)
} else {
- val leftRatio = if (leftRows != 0) {
- BigDecimal(outputRows) / BigDecimal(leftRows)
- } else {
- BigDecimal(0)
- }
- val rightRatio = if (rightRows != 0) {
- BigDecimal(outputRows) / BigDecimal(rightRows)
- } else {
- BigDecimal(0)
- }
val oldColStat = oldAttrStats(a)
val oldNdv = oldColStat.distinctCount
- // We only change (scale down) the number of distinct values if the number of rows
- // decreases after join, because join won't produce new values even if the number of
- // rows increases.
- val newNdv = if (join.left.outputSet.contains(a) && leftRatio < 1) {
- ceil(BigDecimal(oldNdv) * leftRatio)
- } else if (join.right.outputSet.contains(a) && rightRatio < 1) {
- ceil(BigDecimal(oldNdv) * rightRatio)
+ val newNdv = if (join.left.outputSet.contains(a)) {
+ updateNdv(oldNumRows = leftRows, newNumRows = outputRows, oldNdv = oldNdv)
} else {
- oldNdv
+ updateNdv(oldNumRows = rightRows, newNumRows = outputRows, oldNdv = oldNdv)
}
+ val newColStat = oldColStat.copy(distinctCount = newNdv)
// TODO: support nullCount updates for specific outer joins
- outputAttrStats += a -> oldColStat.copy(distinctCount = newNdv)
+ outputAttrStats += a -> newColStat
}
-
}
outputAttrStats
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
index 6fc828f63f152..85b368c862630 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
@@ -122,7 +122,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
logDebug(
s"""
|=== Result of Batch ${batch.name} ===
- |${sideBySide(plan.treeString, curPlan.treeString).mkString("\n")}
+ |${sideBySide(batchStartPlan.treeString, curPlan.treeString).mkString("\n")}
""".stripMargin)
} else {
logTrace(s"Batch ${batch.name} has no effect.")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index cc4c0835954ba..ae5c513eb040b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -340,8 +340,18 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
arg
}
case tuple@(arg1: TreeNode[_], arg2: TreeNode[_]) =>
- val newChild1 = f(arg1.asInstanceOf[BaseType])
- val newChild2 = f(arg2.asInstanceOf[BaseType])
+ val newChild1 = if (containsChild(arg1)) {
+ f(arg1.asInstanceOf[BaseType])
+ } else {
+ arg1.asInstanceOf[BaseType]
+ }
+
+ val newChild2 = if (containsChild(arg2)) {
+ f(arg2.asInstanceOf[BaseType])
+ } else {
+ arg2.asInstanceOf[BaseType]
+ }
+
if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) {
changed = true
(newChild1, newChild2)
@@ -444,6 +454,11 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
case None => Nil
case Some(null) => Nil
case Some(any) => any :: Nil
+ case table: CatalogTable =>
+ table.storage.serde match {
+ case Some(serde) => table.identifier :: serde :: Nil
+ case _ => table.identifier :: Nil
+ }
case other => other :: Nil
}.mkString(", ")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
index eb6aad5b2d2bb..02cfa6e1b8afd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
@@ -20,6 +20,8 @@ package org.apache.spark.sql.catalyst.util
import java.sql.{Date, Timestamp}
import java.text.{DateFormat, SimpleDateFormat}
import java.util.{Calendar, Locale, TimeZone}
+import java.util.concurrent.ConcurrentHashMap
+import java.util.function.{Function => JFunction}
import javax.xml.bind.DatatypeConverter
import scala.annotation.tailrec
@@ -30,7 +32,7 @@ import org.apache.spark.unsafe.types.UTF8String
* Helper functions for converting between internal and external date and time representations.
* Dates are exposed externally as java.sql.Date and are represented internally as the number of
* dates since the Unix epoch (1970-01-01). Timestamps are exposed externally as java.sql.Timestamp
- * and are stored internally as longs, which are capable of storing timestamps with 100 nanosecond
+ * and are stored internally as longs, which are capable of storing timestamps with microsecond
* precision.
*/
object DateTimeUtils {
@@ -98,6 +100,15 @@ object DateTimeUtils {
sdf
}
+ private val computedTimeZones = new ConcurrentHashMap[String, TimeZone]
+ private val computeTimeZone = new JFunction[String, TimeZone] {
+ override def apply(timeZoneId: String): TimeZone = TimeZone.getTimeZone(timeZoneId)
+ }
+
+ def getTimeZone(timeZoneId: String): TimeZone = {
+ computedTimeZones.computeIfAbsent(timeZoneId, computeTimeZone)
+ }
+
def newDateFormat(formatString: String, timeZone: TimeZone): DateFormat = {
val sdf = new SimpleDateFormat(formatString, Locale.US)
sdf.setTimeZone(timeZone)
@@ -388,13 +399,14 @@ object DateTimeUtils {
digitsMilli += 1
}
- if (!justTime && isInvalidDate(segments(0), segments(1), segments(2))) {
- return None
+ // We are truncating the nanosecond part, which results in loss of precision
+ while (digitsMilli > 6) {
+ segments(6) /= 10
+ digitsMilli -= 1
}
- // Instead of return None, we truncate the fractional seconds to prevent inserting NULL
- if (segments(6) > 999999) {
- segments(6) = segments(6).toString.take(6).toInt
+ if (!justTime && isInvalidDate(segments(0), segments(1), segments(2))) {
+ return None
}
if (segments(3) < 0 || segments(3) > 23 || segments(4) < 0 || segments(4) > 59 ||
@@ -407,7 +419,7 @@ object DateTimeUtils {
Calendar.getInstance(timeZone)
} else {
Calendar.getInstance(
- TimeZone.getTimeZone(f"GMT${tz.get.toChar}${segments(7)}%02d:${segments(8)}%02d"))
+ getTimeZone(f"GMT${tz.get.toChar}${segments(7)}%02d:${segments(8)}%02d"))
}
c.set(Calendar.MILLISECOND, 0)
@@ -592,7 +604,14 @@ object DateTimeUtils {
*/
private[this] def getYearAndDayInYear(daysSince1970: SQLDate): (Int, Int) = {
// add the difference (in days) between 1.1.1970 and the artificial year 0 (-17999)
- val daysNormalized = daysSince1970 + toYearZero
+ var daysSince1970Tmp = daysSince1970
+ // Since the Julian calendar was replaced with the Gregorian calendar,
+ // the 10 days after Oct. 4, 1582 were skipped.
+ // (1582-10-04 is -141428 days since 1970-01-01)
+ if (daysSince1970 <= -141428) {
+ daysSince1970Tmp -= 10
+ }
+ val daysNormalized = daysSince1970Tmp + toYearZero
val numOfQuarterCenturies = daysNormalized / daysIn400Years
val daysInThis400 = daysNormalized % daysIn400Years + 1
val (years, dayInYear) = numYears(daysInThis400)
@@ -1027,7 +1046,7 @@ object DateTimeUtils {
* representation in their timezone.
*/
def fromUTCTime(time: SQLTimestamp, timeZone: String): SQLTimestamp = {
- convertTz(time, TimeZoneGMT, TimeZone.getTimeZone(timeZone))
+ convertTz(time, TimeZoneGMT, getTimeZone(timeZone))
}
/**
@@ -1035,7 +1054,7 @@ object DateTimeUtils {
* string representation in their timezone.
*/
def toUTCTime(time: SQLTimestamp, timeZone: String): SQLTimestamp = {
- convertTz(time, TimeZone.getTimeZone(timeZone), TimeZoneGMT)
+ convertTz(time, getTimeZone(timeZone), TimeZoneGMT)
}
/**
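
Note: getTimeZone caches TimeZone instances in a ConcurrentHashMap because TimeZone.getTimeZone(id) is relatively expensive when invoked per row; the explicit java.util.function.Function instance keeps this compatible with Scala 2.11, which does not apply SAM conversion to Java functional interfaces by default. A hypothetical usage sketch, assuming only the patched DateTimeUtils above:

    import org.apache.spark.sql.catalyst.util.DateTimeUtils

    // Repeated lookups for the same id return the cached instance, so per-row calls
    // in fromUTCTime/toUTCTime no longer re-parse the time zone id.
    val tz1 = DateTimeUtils.getTimeZone("America/Los_Angeles")
    val tz2 = DateTimeUtils.getTimeZone("America/Los_Angeles")
    assert(tz1 eq tz2)
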
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
index 7101ca5a17de9..45225779bffcb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
@@ -70,7 +70,9 @@ object TypeUtils {
def compareBinary(x: Array[Byte], y: Array[Byte]): Int = {
for (i <- 0 until x.length; if i < y.length) {
- val res = x(i).compareTo(y(i))
+ val v1 = x(i) & 0xff
+ val v2 = y(i) & 0xff
+ val res = v1 - v2
if (res != 0) return res
}
x.length - y.length
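
Note: the change makes the byte comparison unsigned, which is what lexicographic byte-string ordering requires. A small worked example of the difference (the values are illustrative):

    // 0x7f is 127 and 0x80 is -128 as signed JVM bytes, but 128 when treated as unsigned.
    val a: Byte = 0x7f.toByte
    val b: Byte = 0x80.toByte

    // Old behaviour (signed compareTo): 127 > -128, so a sorted after b.
    assert(a.compareTo(b) > 0)

    // New behaviour (unsigned): 127 < 128, so a sorts before b.
    assert(((a & 0xff) - (b & 0xff)) < 0)
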
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 2e1798e22b9fc..ebabd1a1396b4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -29,6 +29,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.network.util.ByteUnit
import org.apache.spark.sql.catalyst.analysis.Resolver
+import org.apache.spark.util.Utils
import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -196,6 +197,14 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+ val ESCAPED_STRING_LITERALS = buildConf("spark.sql.parser.escapedStringLiterals")
+ .internal()
+ .doc("When true, string literals (including regex patterns) remain escaped in our SQL " +
+ "parser. The default is false since Spark 2.0. Setting it to true can restore the behavior " +
+ "prior to Spark 2.0.")
+ .booleanConf
+ .createWithDefault(false)
+
val PARQUET_SCHEMA_MERGING_ENABLED = buildConf("spark.sql.parquet.mergeSchema")
.doc("When true, the Parquet data source merges schemas collected from all data files, " +
"otherwise the schema is picked from the summary file or a random data file " +
@@ -260,8 +269,9 @@ object SQLConf {
val PARQUET_OUTPUT_COMMITTER_CLASS = buildConf("spark.sql.parquet.output.committer.class")
.doc("The output committer class used by Parquet. The specified class needs to be a " +
- "subclass of org.apache.hadoop.mapreduce.OutputCommitter. Typically, it's also a subclass " +
- "of org.apache.parquet.hadoop.ParquetOutputCommitter.")
+ "subclass of org.apache.hadoop.mapreduce.OutputCommitter. Typically, it's also a subclass " +
+ "of org.apache.parquet.hadoop.ParquetOutputCommitter. If it is not, then metadata summaries" +
+ "will never be created, irrespective of the value of parquet.enable.summary-metadata")
.internal()
.stringConf
.createWithDefault("org.apache.parquet.hadoop.ParquetOutputCommitter")
@@ -295,7 +305,7 @@ object SQLConf {
val HIVE_MANAGE_FILESOURCE_PARTITIONS =
buildConf("spark.sql.hive.manageFilesourcePartitions")
.doc("When true, enable metastore partition management for file source tables as well. " +
- "This includes both datasource and converted Hive tables. When partition managment " +
+ "This includes both datasource and converted Hive tables. When partition management " +
"is enabled, datasource tables store partition in the Hive metastore, and use the " +
"metastore to prune partitions during query planning.")
.booleanConf
@@ -337,7 +347,8 @@ object SQLConf {
.createWithDefault(true)
val COLUMN_NAME_OF_CORRUPT_RECORD = buildConf("spark.sql.columnNameOfCorruptRecord")
- .doc("The name of internal column for storing raw/un-parsed JSON records that fail to parse.")
+ .doc("The name of internal column for storing raw/un-parsed JSON and CSV records that fail " +
+ "to parse.")
.stringConf
.createWithDefault("_corrupt_record")
@@ -421,6 +432,12 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+ val GROUP_BY_ALIASES = buildConf("spark.sql.groupByAliases")
+ .doc("When true, aliases in a select list can be used in group by clauses. When false, " +
+ "an analysis exception is thrown in the case.")
+ .booleanConf
+ .createWithDefault(true)
+
// The output committer class used by data sources. The specified class needs to be a
// subclass of org.apache.hadoop.mapreduce.OutputCommitter.
val OUTPUT_COMMITTER_CLASS =
@@ -521,8 +538,7 @@ object SQLConf {
val IGNORE_CORRUPT_FILES = buildConf("spark.sql.files.ignoreCorruptFiles")
.doc("Whether to ignore corrupt files. If true, the Spark jobs will continue to run when " +
- "encountering corrupted or non-existing and contents that have been read will still be " +
- "returned.")
+ "encountering corrupted files and the contents that have been read will still be returned.")
.booleanConf
.createWithDefault(false)
@@ -760,27 +776,59 @@ object SQLConf {
.stringConf
.createWithDefaultFunction(() => TimeZone.getDefault.getID)
+ val WINDOW_EXEC_BUFFER_IN_MEMORY_THRESHOLD =
+ buildConf("spark.sql.windowExec.buffer.in.memory.threshold")
+ .internal()
+ .doc("Threshold for number of rows guaranteed to be held in memory by the window operator")
+ .intConf
+ .createWithDefault(4096)
+
val WINDOW_EXEC_BUFFER_SPILL_THRESHOLD =
buildConf("spark.sql.windowExec.buffer.spill.threshold")
.internal()
- .doc("Threshold for number of rows buffered in window operator")
+ .doc("Threshold for number of rows to be spilled by window operator")
.intConf
- .createWithDefault(4096)
+ .createWithDefault(UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt)
+
+ val SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD =
+ buildConf("spark.sql.sortMergeJoinExec.buffer.in.memory.threshold")
+ .internal()
+ .doc("Threshold for number of rows guaranteed to be held in memory by the sort merge " +
+ "join operator")
+ .intConf
+ .createWithDefault(Int.MaxValue)
val SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD =
buildConf("spark.sql.sortMergeJoinExec.buffer.spill.threshold")
.internal()
- .doc("Threshold for number of rows buffered in sort merge join operator")
+ .doc("Threshold for number of rows to be spilled by sort merge join operator")
.intConf
- .createWithDefault(Int.MaxValue)
+ .createWithDefault(UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt)
+
+ val CARTESIAN_PRODUCT_EXEC_BUFFER_IN_MEMORY_THRESHOLD =
+ buildConf("spark.sql.cartesianProductExec.buffer.in.memory.threshold")
+ .internal()
+ .doc("Threshold for number of rows guaranteed to be held in memory by the cartesian " +
+ "product operator")
+ .intConf
+ .createWithDefault(4096)
val CARTESIAN_PRODUCT_EXEC_BUFFER_SPILL_THRESHOLD =
buildConf("spark.sql.cartesianProductExec.buffer.spill.threshold")
.internal()
- .doc("Threshold for number of rows buffered in cartesian product operator")
+ .doc("Threshold for number of rows to be spilled by cartesian product operator")
.intConf
.createWithDefault(UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt)
+ val SQL_OPTIONS_REDACTION_PATTERN =
+ buildConf("spark.sql.redaction.options.regex")
+ .doc("Regex to decide which keys in a Spark SQL command's options map contain sensitive " +
+ "information. The values of options whose names that match this regex will be redacted " +
+ "in the explain output. This redaction is applied on top of the global redaction " +
+ s"configuration defined by ${SECRET_REDACTION_PATTERN.key}.")
+ .regexConf
+ .createWithDefault("(?i)url".r)
+
object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
}
@@ -911,6 +959,8 @@ class SQLConf extends Serializable with Logging {
def constraintPropagationEnabled: Boolean = getConf(CONSTRAINT_PROPAGATION_ENABLED)
+ def escapedStringLiterals: Boolean = getConf(ESCAPED_STRING_LITERALS)
+
/**
* Returns the [[Resolver]] for the current configuration, which can be used to determine if two
* identifiers are equal.
@@ -1003,6 +1053,8 @@ class SQLConf extends Serializable with Logging {
def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL)
+ def groupByAliases: Boolean = getConf(GROUP_BY_ALIASES)
+
def crossJoinEnabled: Boolean = getConf(SQLConf.CROSS_JOINS_ENABLED)
def sessionLocalTimeZone: String = getConf(SQLConf.SESSION_LOCAL_TIMEZONE)
@@ -1019,11 +1071,19 @@ class SQLConf extends Serializable with Logging {
def joinReorderDPStarFilter: Boolean = getConf(SQLConf.JOIN_REORDER_DP_STAR_FILTER)
+ def windowExecBufferInMemoryThreshold: Int = getConf(WINDOW_EXEC_BUFFER_IN_MEMORY_THRESHOLD)
+
def windowExecBufferSpillThreshold: Int = getConf(WINDOW_EXEC_BUFFER_SPILL_THRESHOLD)
+ def sortMergeJoinExecBufferInMemoryThreshold: Int =
+ getConf(SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD)
+
def sortMergeJoinExecBufferSpillThreshold: Int =
getConf(SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD)
+ def cartesianProductExecBufferInMemoryThreshold: Int =
+ getConf(CARTESIAN_PRODUCT_EXEC_BUFFER_IN_MEMORY_THRESHOLD)
+
def cartesianProductExecBufferSpillThreshold: Int =
getConf(CARTESIAN_PRODUCT_EXEC_BUFFER_SPILL_THRESHOLD)
@@ -1104,10 +1164,12 @@ class SQLConf extends Serializable with Logging {
* not set yet, return `defaultValue`.
*/
def getConfString(key: String, defaultValue: String): String = {
- val entry = sqlConfEntries.get(key)
- if (entry != null && defaultValue != "") {
- // Only verify configs in the SQLConf object
- entry.valueConverter(defaultValue)
+ if (defaultValue != null && defaultValue != "") {
+ val entry = sqlConfEntries.get(key)
+ if (entry != null) {
+ // Only verify configs in the SQLConf object
+ entry.valueConverter(defaultValue)
+ }
}
Option(settings.get(key)).getOrElse(defaultValue)
}
@@ -1129,6 +1191,17 @@ class SQLConf extends Serializable with Logging {
}.toSeq
}
+ /**
+ * Redacts the given option map according to the description of SQL_OPTIONS_REDACTION_PATTERN.
+ */
+ def redactOptions(options: Map[String, String]): Map[String, String] = {
+ val regexes = Seq(
+ getConf(SQL_OPTIONS_REDACTION_PATTERN),
+ SECRET_REDACTION_PATTERN.readFrom(reader))
+
+ regexes.foldLeft(options.toSeq) { case (opts, r) => Utils.redact(Some(r), opts) }.toMap
+ }
+
/**
* Return whether a given key is set in this [[SQLConf]].
*/
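
Note: redactOptions folds every configured redaction regex over the option map, so both spark.sql.redaction.options.regex and the global pattern referenced by SECRET_REDACTION_PATTERN apply. A self-contained sketch of that fold, with a hypothetical redact helper standing in for Utils.redact:

    import scala.util.matching.Regex

    // Hypothetical stand-in for Utils.redact: hide values whose key matches the regex.
    def redact(regex: Regex, opts: Seq[(String, String)]): Seq[(String, String)] =
      opts.map { case (k, v) =>
        if (regex.findFirstIn(k).isDefined) (k, "*********(redacted)") else (k, v)
      }

    val regexes = Seq("(?i)url".r, "(?i)secret|password".r)
    val options = Map("url" -> "jdbc:postgresql://host/db?password=x", "dbtable" -> "t")

    // Same shape as redactOptions: fold each regex over the options and rebuild the map.
    val redacted = regexes.foldLeft(options.toSeq) { case (opts, r) => redact(r, opts) }.toMap
    // redacted("url") is hidden; redacted("dbtable") == "t"
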
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
index af1a9cee2962a..c6c0a605d89ff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
@@ -81,4 +81,10 @@ object StaticSQLConf {
"SQL configuration and the current database.")
.booleanConf
.createWithDefault(false)
+
+ val SPARK_SESSION_EXTENSIONS = buildStaticConf("spark.sql.extensions")
+ .doc("Name of the class used to configure Spark Session extensions. The class should " +
+ "implement Function1[SparkSessionExtension, Unit], and must have a no-args constructor.")
+ .stringConf
+ .createOptional
}
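
Note: for spark.sql.extensions, a minimal sketch of a conforming extensions class, assuming the configuring type is org.apache.spark.sql.SparkSessionExtensions; what it injects is up to the user:

    import org.apache.spark.sql.SparkSessionExtensions

    // Hypothetical sketch: a no-args class implementing Function1[SparkSessionExtensions, Unit].
    class MyExtensions extends (SparkSessionExtensions => Unit) {
      override def apply(extensions: SparkSessionExtensions): Unit = {
        // e.g. extensions.injectResolutionRule(...) or extensions.injectParser(...)
      }
    }

It would be enabled with spark.sql.extensions=com.example.MyExtensions (a hypothetical fully-qualified class name).
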
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index e8f6884c025c2..6da4f28b12962 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -126,20 +126,36 @@ final class Decimal extends Ordered[Decimal] with Serializable {
def set(decimal: BigDecimal): Decimal = {
this.decimalVal = decimal
this.longVal = 0L
- this._precision = decimal.precision
+ if (decimal.precision <= decimal.scale) {
+ // For Decimal, we expect the precision to be equal to or greater than the scale. However,
+ // in BigDecimal, the digit count starts from the leftmost nonzero digit of the exact
+ // result. For example, the precision of 0.01 is 1 by that definition, but
+ // the scale is 2. The expected precision should be 3.
+ this._precision = decimal.scale + 1
+ } else {
+ this._precision = decimal.precision
+ }
this._scale = decimal.scale
this
}
/**
- * Set this Decimal to the given BigInteger value. Will have precision 38 and scale 0.
+ * Set this Decimal to the given BigInteger value. If the value is not in the range of long,
+ * convert it to BigDecimal; the precision and scale are then based on the converted value.
+ *
+ * This code avoids BigDecimal object allocation where possible to improve runtime efficiency.
*/
def set(bigintval: BigInteger): Decimal = {
- this.decimalVal = null
- this.longVal = bigintval.longValueExact()
- this._precision = DecimalType.MAX_PRECISION
- this._scale = 0
- this
+ try {
+ this.decimalVal = null
+ this.longVal = bigintval.longValueExact()
+ this._precision = DecimalType.MAX_PRECISION
+ this._scale = 0
+ this
+ } catch {
+ case _: ArithmeticException =>
+ set(BigDecimal(bigintval))
+ }
}
/**
@@ -218,22 +234,17 @@ final class Decimal extends Ordered[Decimal] with Serializable {
changePrecision(precision, scale, ROUND_HALF_UP)
}
- def changePrecision(precision: Int, scale: Int, mode: Int): Boolean = mode match {
- case java.math.BigDecimal.ROUND_HALF_UP => changePrecision(precision, scale, ROUND_HALF_UP)
- case java.math.BigDecimal.ROUND_HALF_EVEN => changePrecision(precision, scale, ROUND_HALF_EVEN)
- }
-
/**
* Create new `Decimal` with given precision and scale.
*
- * @return `Some(decimal)` if successful or `None` if overflow would occur
+ * @return a non-null `Decimal` value if successful or `null` if overflow would occur.
*/
private[sql] def toPrecision(
precision: Int,
scale: Int,
- roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP): Option[Decimal] = {
+ roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP): Decimal = {
val copy = clone()
- if (copy.changePrecision(precision, scale, roundMode)) Some(copy) else None
+ if (copy.changePrecision(precision, scale, roundMode)) copy else null
}
/**
@@ -241,8 +252,10 @@ final class Decimal extends Ordered[Decimal] with Serializable {
*
* @return true if successful, false if overflow would occur
*/
- private[sql] def changePrecision(precision: Int, scale: Int,
- roundMode: BigDecimal.RoundingMode.Value): Boolean = {
+ private[sql] def changePrecision(
+ precision: Int,
+ scale: Int,
+ roundMode: BigDecimal.RoundingMode.Value): Boolean = {
// fast path for UnsafeProjection
if (precision == this.precision && scale == this.scale) {
return true
@@ -377,14 +390,20 @@ final class Decimal extends Ordered[Decimal] with Serializable {
def floor: Decimal = if (scale == 0) this else {
val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision
- toPrecision(newPrecision, 0, ROUND_FLOOR).getOrElse(
- throw new AnalysisException(s"Overflow when setting precision to $newPrecision"))
+ val res = toPrecision(newPrecision, 0, ROUND_FLOOR)
+ if (res == null) {
+ throw new AnalysisException(s"Overflow when setting precision to $newPrecision")
+ }
+ res
}
def ceil: Decimal = if (scale == 0) this else {
val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision
- toPrecision(newPrecision, 0, ROUND_CEILING).getOrElse(
- throw new AnalysisException(s"Overflow when setting precision to $newPrecision"))
+ val res = toPrecision(newPrecision, 0, ROUND_CEILING)
+ if (res == null) {
+ throw new AnalysisException(s"Overflow when setting precision to $newPrecision")
+ }
+ res
}
}
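
Note: the precision fix in set(decimal: BigDecimal) matters for values such as 0.01, where BigDecimal reports a precision smaller than the scale. A short check of the numbers quoted in the comment:

    // BigDecimal counts digits from the leftmost nonzero digit of the unscaled value,
    // so 0.01 reports precision 1 while its scale is 2; Spark's Decimal needs
    // precision >= scale, hence the adjusted precision of scale + 1 = 3.
    val d = BigDecimal("0.01")
    assert(d.precision == 1 && d.scale == 2)
    val adjustedPrecision = if (d.precision <= d.scale) d.scale + 1 else d.precision
    assert(adjustedPrecision == 3)
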
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala
new file mode 100644
index 0000000000000..e881685ce6262
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.util
+
+import org.apache.spark.internal.Logging
+
+
+/**
+ * Utils for handling schemas.
+ *
+ * TODO: Merge this file with [[org.apache.spark.ml.util.SchemaUtils]].
+ */
+private[spark] object SchemaUtils extends Logging {
+
+ /**
+ * Checks if input column names have duplicate identifiers. Logs a warning message if
+ * duplicates exist.
+ *
+ * @param columnNames column names to check
+ * @param colType column type name, used in a warning message
+ * @param caseSensitiveAnalysis whether duplication checks should be case sensitive or not
+ */
+ def checkColumnNameDuplication(
+ columnNames: Seq[String], colType: String, caseSensitiveAnalysis: Boolean): Unit = {
+ val names = if (caseSensitiveAnalysis) {
+ columnNames
+ } else {
+ columnNames.map(_.toLowerCase)
+ }
+ if (names.distinct.length != names.length) {
+ val duplicateColumns = names.groupBy(identity).collect {
+ case (x, ys) if ys.length > 1 => s"`$x`"
+ }
+ logWarning(s"Found duplicate column(s) $colType: ${duplicateColumns.mkString(", ")}. " +
+ "You might need to assign different column names.")
+ }
+ }
+}
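
Note: a hypothetical usage of the new helper; it only logs a warning (it does not throw), so callers that need a hard failure still have to enforce it themselves. The colType string below is illustrative:

    import org.apache.spark.sql.util.SchemaUtils

    // With case-insensitive analysis, "a" and "A" collide and a warning such as
    // "Found duplicate column(s) in the table definition: `a`" is logged.
    SchemaUtils.checkColumnNameDuplication(
      columnNames = Seq("a", "A", "b"),
      colType = "in the table definition",
      caseSensitiveAnalysis = false)

    // With case-sensitive analysis the same input logs nothing.
    SchemaUtils.checkColumnNameDuplication(Seq("a", "A", "b"), "in the table definition", true)
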
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 70ad064f93ebc..c0e4e37ae9ed9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -23,8 +23,9 @@ import java.sql.{Date, Timestamp}
import scala.reflect.runtime.universe.typeOf
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal, SpecificInternalRow}
-import org.apache.spark.sql.catalyst.expressions.objects.NewInstance
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal, SpecificInternalRow, UpCast}
+import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
@@ -83,6 +84,8 @@ case class MultipleConstructorsData(a: Int, b: String, c: Double) {
def this(b: String, a: Int) = this(a, b, c = 1.0)
}
+case class SpecialCharAsFieldData(`field.1`: String, `field 2`: String)
+
object TestingUDT {
@SQLUserDefinedType(udt = classOf[NestedStructUDT])
class NestedStruct(val a: Integer, val b: Long, val c: Double)
@@ -349,4 +352,23 @@ class ScalaReflectionSuite extends SparkFunSuite {
}
}
}
+
+ test("SPARK-22472: add null check for top-level primitive values") {
+ assert(deserializerFor[Int].isInstanceOf[AssertNotNull])
+ assert(!deserializerFor[String].isInstanceOf[AssertNotNull])
+ }
+
+ test("SPARK-22442: Generate correct field names for special characters") {
+ val serializer = serializerFor[SpecialCharAsFieldData](BoundReference(
+ 0, ObjectType(classOf[SpecialCharAsFieldData]), nullable = false))
+ val deserializer = deserializerFor[SpecialCharAsFieldData]
+ assert(serializer.dataType(0).name == "field.1")
+ assert(serializer.dataType(1).name == "field 2")
+
+ val argumentsFields = deserializer.asInstanceOf[NewInstance].arguments.flatMap { _.collect {
+ case UpCast(u: UnresolvedAttribute, _, _) => u.nameParts
+ }}
+ assert(argumentsFields(0) == Seq("field.1"))
+ assert(argumentsFields(1) == Seq("field 2"))
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index d2ebca5a83dd3..5050318d96358 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -24,7 +24,8 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count, Max}
-import org.apache.spark.sql.catalyst.plans.{Cross, Inner, LeftOuter, RightOuter}
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.catalyst.plans.{Cross, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData}
import org.apache.spark.sql.types._
@@ -152,7 +153,7 @@ class AnalysisErrorSuite extends AnalysisTest {
"not supported within a window function" :: Nil)
errorTest(
- "distinct window function",
+ "distinct aggregate function in window",
testRelation2.select(
WindowExpression(
AggregateExpression(Count(UnresolvedAttribute("b")), Complete, isDistinct = true),
@@ -162,6 +163,16 @@ class AnalysisErrorSuite extends AnalysisTest {
UnspecifiedFrame)).as('window)),
"Distinct window functions are not supported" :: Nil)
+ errorTest(
+ "distinct function",
+ CatalystSqlParser.parsePlan("SELECT hex(DISTINCT a) FROM TaBlE"),
+ "hex does not support the modifier DISTINCT" :: Nil)
+
+ errorTest(
+ "distinct window function",
+ CatalystSqlParser.parsePlan("SELECT percent_rank(DISTINCT a) over () FROM TaBlE"),
+ "percent_rank does not support the modifier DISTINCT" :: Nil)
+
errorTest(
"nested aggregate functions",
testRelation.groupBy('a)(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
index 82015b1e0671c..08d9313894c2d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
@@ -17,10 +17,11 @@
package org.apache.spark.sql.catalyst.analysis
+import java.net.URI
import java.util.Locale
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
+import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.internal.SQLConf
@@ -32,7 +33,10 @@ trait AnalysisTest extends PlanTest {
private def makeAnalyzer(caseSensitive: Boolean): Analyzer = {
val conf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> caseSensitive)
- val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
+ val catalog = new SessionCatalog(new InMemoryCatalog, FunctionRegistry.builtin, conf)
+ catalog.createDatabase(
+ CatalogDatabase("default", "", new URI("loc"), Map.empty),
+ ignoreIfExists = false)
catalog.createTempView("TaBlE", TestRelations.testRelation, overrideIfExists = true)
catalog.createTempView("TaBlE2", TestRelations.testRelation2, overrideIfExists = true)
new Analyzer(catalog, conf) {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DSLHintSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DSLHintSuite.scala
new file mode 100644
index 0000000000000..48a3ca2ccfb0b
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DSLHintSuite.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.analysis.AnalysisTest
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical._
+
+class DSLHintSuite extends AnalysisTest {
+ lazy val a = 'a.int
+ lazy val b = 'b.string
+ lazy val c = 'c.string
+ lazy val r1 = LocalRelation(a, b, c)
+
+ test("various hint parameters") {
+ comparePlans(
+ r1.hint("hint1"),
+ UnresolvedHint("hint1", Seq(), r1)
+ )
+
+ comparePlans(
+ r1.hint("hint1", 1, "a"),
+ UnresolvedHint("hint1", Seq(1, "a"), r1)
+ )
+
+ comparePlans(
+ r1.hint("hint1", 1, $"a"),
+ UnresolvedHint("hint1", Seq(1, $"a"), r1)
+ )
+
+ comparePlans(
+ r1.hint("hint1", Seq(1, 2, 3), Seq($"a", $"b", $"c")),
+ UnresolvedHint("hint1", Seq(Seq(1, 2, 3), Seq($"a", $"b", $"c")), r1)
+ )
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
index 8f43171f309a9..3df2530ece636 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
@@ -90,8 +90,14 @@ class DecimalPrecisionSuite extends PlanTest with BeforeAndAfter {
checkType(Average(d1), DecimalType(6, 5))
checkType(Add(Add(d1, d2), d1), DecimalType(7, 2))
+ checkType(Add(Add(d1, d1), d1), DecimalType(4, 1))
+ checkType(Add(d1, Add(d1, d1)), DecimalType(4, 1))
checkType(Add(Add(Add(d1, d2), d1), d2), DecimalType(8, 2))
checkType(Add(Add(d1, d2), Add(d1, d2)), DecimalType(7, 2))
+ checkType(Subtract(Subtract(d2, d1), d1), DecimalType(7, 2))
+ checkType(Multiply(Multiply(d1, d1), d2), DecimalType(11, 4))
+ checkType(Divide(d2, Add(d1, d1)), DecimalType(10, 6))
+ checkType(Sum(Add(d1, d1)), DecimalType(13, 1))
}
test("Comparison operations") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
index d101e2227462d..3d5148008c628 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
@@ -28,68 +28,70 @@ class ResolveHintsSuite extends AnalysisTest {
test("invalid hints should be ignored") {
checkAnalysis(
- Hint("some_random_hint_that_does_not_exist", Seq("TaBlE"), table("TaBlE")),
+ UnresolvedHint("some_random_hint_that_does_not_exist", Seq("TaBlE"), table("TaBlE")),
testRelation,
caseSensitive = false)
}
test("case-sensitive or insensitive parameters") {
checkAnalysis(
- Hint("MAPJOIN", Seq("TaBlE"), table("TaBlE")),
- BroadcastHint(testRelation),
+ UnresolvedHint("MAPJOIN", Seq("TaBlE"), table("TaBlE")),
+ ResolvedHint(testRelation, HintInfo(isBroadcastable = Option(true))),
caseSensitive = false)
checkAnalysis(
- Hint("MAPJOIN", Seq("table"), table("TaBlE")),
- BroadcastHint(testRelation),
+ UnresolvedHint("MAPJOIN", Seq("table"), table("TaBlE")),
+ ResolvedHint(testRelation, HintInfo(isBroadcastable = Option(true))),
caseSensitive = false)
checkAnalysis(
- Hint("MAPJOIN", Seq("TaBlE"), table("TaBlE")),
- BroadcastHint(testRelation),
+ UnresolvedHint("MAPJOIN", Seq("TaBlE"), table("TaBlE")),
+ ResolvedHint(testRelation, HintInfo(isBroadcastable = Option(true))),
caseSensitive = true)
checkAnalysis(
- Hint("MAPJOIN", Seq("table"), table("TaBlE")),
+ UnresolvedHint("MAPJOIN", Seq("table"), table("TaBlE")),
testRelation,
caseSensitive = true)
}
test("multiple broadcast hint aliases") {
checkAnalysis(
- Hint("MAPJOIN", Seq("table", "table2"), table("table").join(table("table2"))),
- Join(BroadcastHint(testRelation), BroadcastHint(testRelation2), Inner, None),
+ UnresolvedHint("MAPJOIN", Seq("table", "table2"), table("table").join(table("table2"))),
+ Join(ResolvedHint(testRelation, HintInfo(isBroadcastable = Option(true))),
+ ResolvedHint(testRelation2, HintInfo(isBroadcastable = Option(true))), Inner, None),
caseSensitive = false)
}
test("do not traverse past existing broadcast hints") {
checkAnalysis(
- Hint("MAPJOIN", Seq("table"), BroadcastHint(table("table").where('a > 1))),
- BroadcastHint(testRelation.where('a > 1)).analyze,
+ UnresolvedHint("MAPJOIN", Seq("table"),
+ ResolvedHint(table("table").where('a > 1), HintInfo(isBroadcastable = Option(true)))),
+ ResolvedHint(testRelation.where('a > 1), HintInfo(isBroadcastable = Option(true))).analyze,
caseSensitive = false)
}
test("should work for subqueries") {
checkAnalysis(
- Hint("MAPJOIN", Seq("tableAlias"), table("table").as("tableAlias")),
- BroadcastHint(testRelation),
+ UnresolvedHint("MAPJOIN", Seq("tableAlias"), table("table").as("tableAlias")),
+ ResolvedHint(testRelation, HintInfo(isBroadcastable = Option(true))),
caseSensitive = false)
checkAnalysis(
- Hint("MAPJOIN", Seq("tableAlias"), table("table").subquery('tableAlias)),
- BroadcastHint(testRelation),
+ UnresolvedHint("MAPJOIN", Seq("tableAlias"), table("table").subquery('tableAlias)),
+ ResolvedHint(testRelation, HintInfo(isBroadcastable = Option(true))),
caseSensitive = false)
// Negative case: if the alias doesn't match, don't match the original table name.
checkAnalysis(
- Hint("MAPJOIN", Seq("table"), table("table").as("tableAlias")),
+ UnresolvedHint("MAPJOIN", Seq("table"), table("table").as("tableAlias")),
testRelation,
caseSensitive = false)
}
test("do not traverse past subquery alias") {
checkAnalysis(
- Hint("MAPJOIN", Seq("table"), table("table").where('a > 1).subquery('tableAlias)),
+ UnresolvedHint("MAPJOIN", Seq("table"), table("table").where('a > 1).subquery('tableAlias)),
testRelation.where('a > 1).analyze,
caseSensitive = false)
}
@@ -102,7 +104,8 @@ class ResolveHintsSuite extends AnalysisTest {
|SELECT /*+ BROADCAST(ctetable) */ * FROM ctetable
""".stripMargin
),
- BroadcastHint(testRelation.where('a > 1).select('a)).select('a).analyze,
+ ResolvedHint(testRelation.where('a > 1).select('a), HintInfo(isBroadcastable = Option(true)))
+ .select('a).analyze,
caseSensitive = false)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala
index f45a826869842..d0fe815052256 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala
@@ -22,6 +22,7 @@ import org.scalatest.BeforeAndAfter
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal, Rand}
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.{LongType, NullType, TimestampType}
/**
@@ -91,12 +92,13 @@ class ResolveInlineTablesSuite extends AnalysisTest with BeforeAndAfter {
test("convert TimeZoneAwareExpression") {
val table = UnresolvedInlineTable(Seq("c1"),
Seq(Seq(Cast(lit("1991-12-06 00:00:00.0"), TimestampType))))
- val converted = ResolveInlineTables(conf).convert(table)
+ val withTimeZone = ResolveTimeZone(conf).apply(table)
+ val LocalRelation(output, data) = ResolveInlineTables(conf).apply(withTimeZone)
val correct = Cast(lit("1991-12-06 00:00:00.0"), TimestampType)
.withTimeZone(conf.sessionLocalTimeZone).eval().asInstanceOf[Long]
- assert(converted.output.map(_.dataType) == Seq(TimestampType))
- assert(converted.data.size == 1)
- assert(converted.data(0).getLong(0) == correct)
+ assert(output.map(_.dataType) == Seq(TimestampType))
+ assert(data.size == 1)
+ assert(data.head.getLong(0) == correct)
}
test("nullability inference in convert") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index 011d09ff60641..06514ad65daec 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
@@ -787,6 +788,12 @@ class TypeCoercionSuite extends PlanTest {
}
}
+ private val timeZoneResolver = ResolveTimeZone(new SQLConf)
+
+ private def widenSetOperationTypes(plan: LogicalPlan): LogicalPlan = {
+ timeZoneResolver(TypeCoercion.WidenSetOperationTypes(plan))
+ }
+
test("WidenSetOperationTypes for except and intersect") {
val firstTable = LocalRelation(
AttributeReference("i", IntegerType)(),
@@ -799,11 +806,10 @@ class TypeCoercionSuite extends PlanTest {
AttributeReference("f", FloatType)(),
AttributeReference("l", LongType)())
- val wt = TypeCoercion.WidenSetOperationTypes
val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType)
- val r1 = wt(Except(firstTable, secondTable)).asInstanceOf[Except]
- val r2 = wt(Intersect(firstTable, secondTable)).asInstanceOf[Intersect]
+ val r1 = widenSetOperationTypes(Except(firstTable, secondTable)).asInstanceOf[Except]
+ val r2 = widenSetOperationTypes(Intersect(firstTable, secondTable)).asInstanceOf[Intersect]
checkOutput(r1.left, expectedTypes)
checkOutput(r1.right, expectedTypes)
checkOutput(r2.left, expectedTypes)
@@ -838,10 +844,9 @@ class TypeCoercionSuite extends PlanTest {
AttributeReference("p", ByteType)(),
AttributeReference("q", DoubleType)())
- val wt = TypeCoercion.WidenSetOperationTypes
val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType)
- val unionRelation = wt(
+ val unionRelation = widenSetOperationTypes(
Union(firstTable :: secondTable :: thirdTable :: forthTable :: Nil)).asInstanceOf[Union]
assert(unionRelation.children.length == 4)
checkOutput(unionRelation.children.head, expectedTypes)
@@ -862,17 +867,15 @@ class TypeCoercionSuite extends PlanTest {
}
}
- val dp = TypeCoercion.WidenSetOperationTypes
-
val left1 = LocalRelation(
AttributeReference("l", DecimalType(10, 8))())
val right1 = LocalRelation(
AttributeReference("r", DecimalType(5, 5))())
val expectedType1 = Seq(DecimalType(10, 8))
- val r1 = dp(Union(left1, right1)).asInstanceOf[Union]
- val r2 = dp(Except(left1, right1)).asInstanceOf[Except]
- val r3 = dp(Intersect(left1, right1)).asInstanceOf[Intersect]
+ val r1 = widenSetOperationTypes(Union(left1, right1)).asInstanceOf[Union]
+ val r2 = widenSetOperationTypes(Except(left1, right1)).asInstanceOf[Except]
+ val r3 = widenSetOperationTypes(Intersect(left1, right1)).asInstanceOf[Intersect]
checkOutput(r1.children.head, expectedType1)
checkOutput(r1.children.last, expectedType1)
@@ -891,17 +894,17 @@ class TypeCoercionSuite extends PlanTest {
val plan2 = LocalRelation(
AttributeReference("r", rType)())
- val r1 = dp(Union(plan1, plan2)).asInstanceOf[Union]
- val r2 = dp(Except(plan1, plan2)).asInstanceOf[Except]
- val r3 = dp(Intersect(plan1, plan2)).asInstanceOf[Intersect]
+ val r1 = widenSetOperationTypes(Union(plan1, plan2)).asInstanceOf[Union]
+ val r2 = widenSetOperationTypes(Except(plan1, plan2)).asInstanceOf[Except]
+ val r3 = widenSetOperationTypes(Intersect(plan1, plan2)).asInstanceOf[Intersect]
checkOutput(r1.children.last, Seq(expectedType))
checkOutput(r2.right, Seq(expectedType))
checkOutput(r3.right, Seq(expectedType))
- val r4 = dp(Union(plan2, plan1)).asInstanceOf[Union]
- val r5 = dp(Except(plan2, plan1)).asInstanceOf[Except]
- val r6 = dp(Intersect(plan2, plan1)).asInstanceOf[Intersect]
+ val r4 = widenSetOperationTypes(Union(plan2, plan1)).asInstanceOf[Union]
+ val r5 = widenSetOperationTypes(Except(plan2, plan1)).asInstanceOf[Except]
+ val r6 = widenSetOperationTypes(Intersect(plan2, plan1)).asInstanceOf[Intersect]
checkOutput(r4.children.last, Seq(expectedType))
checkOutput(r5.left, Seq(expectedType))
@@ -994,6 +997,9 @@ class TypeCoercionSuite extends PlanTest {
ruleTest(PromoteStrings,
EqualTo(Literal(Array(1, 2)), Literal("123")),
EqualTo(Literal(Array(1, 2)), Literal("123")))
+ ruleTest(PromoteStrings,
+ GreaterThan(Literal("1.5"), Literal(BigDecimal("0.5"))),
+ GreaterThan(Cast(Literal("1.5"), DoubleType), Cast(Literal(BigDecimal("0.5")), DoubleType)))
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala
new file mode 100644
index 0000000000000..2539ea615ff92
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.catalog
+
+import java.net.URI
+import java.nio.file.{Files, Path}
+
+import scala.collection.mutable
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Test Suite for external catalog events
+ */
+class ExternalCatalogEventSuite extends SparkFunSuite {
+
+ protected def newCatalog: ExternalCatalog = new InMemoryCatalog()
+
+ private def testWithCatalog(
+ name: String)(
+ f: (ExternalCatalog, Seq[ExternalCatalogEvent] => Unit) => Unit): Unit = test(name) {
+ val catalog = newCatalog
+ val recorder = mutable.Buffer.empty[ExternalCatalogEvent]
+ catalog.addListener(new ExternalCatalogEventListener {
+ override def onEvent(event: ExternalCatalogEvent): Unit = {
+ recorder += event
+ }
+ })
+ f(catalog, (expected: Seq[ExternalCatalogEvent]) => {
+ val actual = recorder.clone()
+ recorder.clear()
+ assert(expected === actual)
+ })
+ }
+
+ private def createDbDefinition(uri: URI): CatalogDatabase = {
+ CatalogDatabase(name = "db5", description = "", locationUri = uri, Map.empty)
+ }
+
+ private def createDbDefinition(): CatalogDatabase = {
+ createDbDefinition(preparePath(Files.createTempDirectory("db_")))
+ }
+
+ private def preparePath(path: Path): URI = path.normalize().toUri
+
+ testWithCatalog("database") { (catalog, checkEvents) =>
+ // CREATE
+ val dbDefinition = createDbDefinition()
+
+ catalog.createDatabase(dbDefinition, ignoreIfExists = false)
+ checkEvents(CreateDatabasePreEvent("db5") :: CreateDatabaseEvent("db5") :: Nil)
+
+ catalog.createDatabase(dbDefinition, ignoreIfExists = true)
+ checkEvents(CreateDatabasePreEvent("db5") :: CreateDatabaseEvent("db5") :: Nil)
+
+ intercept[AnalysisException] {
+ catalog.createDatabase(dbDefinition, ignoreIfExists = false)
+ }
+ checkEvents(CreateDatabasePreEvent("db5") :: Nil)
+
+ // DROP
+ intercept[AnalysisException] {
+ catalog.dropDatabase("db4", ignoreIfNotExists = false, cascade = false)
+ }
+ checkEvents(DropDatabasePreEvent("db4") :: Nil)
+
+ catalog.dropDatabase("db5", ignoreIfNotExists = false, cascade = false)
+ checkEvents(DropDatabasePreEvent("db5") :: DropDatabaseEvent("db5") :: Nil)
+
+ catalog.dropDatabase("db4", ignoreIfNotExists = true, cascade = false)
+ checkEvents(DropDatabasePreEvent("db4") :: DropDatabaseEvent("db4") :: Nil)
+ }
+
+ testWithCatalog("table") { (catalog, checkEvents) =>
+ val path1 = Files.createTempDirectory("db_")
+ val path2 = Files.createTempDirectory(path1, "tbl_")
+ val uri1 = preparePath(path1)
+ val uri2 = preparePath(path2)
+
+ // CREATE
+ val dbDefinition = createDbDefinition(uri1)
+
+ val storage = CatalogStorageFormat.empty.copy(
+ locationUri = Option(uri2))
+ val tableDefinition = CatalogTable(
+ identifier = TableIdentifier("tbl1", Some("db5")),
+ tableType = CatalogTableType.MANAGED,
+ storage = storage,
+ schema = new StructType().add("id", "long"))
+
+ catalog.createDatabase(dbDefinition, ignoreIfExists = false)
+ checkEvents(CreateDatabasePreEvent("db5") :: CreateDatabaseEvent("db5") :: Nil)
+
+ catalog.createTable(tableDefinition, ignoreIfExists = false)
+ checkEvents(CreateTablePreEvent("db5", "tbl1") :: CreateTableEvent("db5", "tbl1") :: Nil)
+
+ catalog.createTable(tableDefinition, ignoreIfExists = true)
+ checkEvents(CreateTablePreEvent("db5", "tbl1") :: CreateTableEvent("db5", "tbl1") :: Nil)
+
+ intercept[AnalysisException] {
+ catalog.createTable(tableDefinition, ignoreIfExists = false)
+ }
+ checkEvents(CreateTablePreEvent("db5", "tbl1") :: Nil)
+
+ // RENAME
+ catalog.renameTable("db5", "tbl1", "tbl2")
+ checkEvents(
+ RenameTablePreEvent("db5", "tbl1", "tbl2") ::
+ RenameTableEvent("db5", "tbl1", "tbl2") :: Nil)
+
+ intercept[AnalysisException] {
+ catalog.renameTable("db5", "tbl1", "tbl2")
+ }
+ checkEvents(RenameTablePreEvent("db5", "tbl1", "tbl2") :: Nil)
+
+ // DROP
+ intercept[AnalysisException] {
+ catalog.dropTable("db5", "tbl1", ignoreIfNotExists = false, purge = true)
+ }
+ checkEvents(DropTablePreEvent("db5", "tbl1") :: Nil)
+
+ catalog.dropTable("db5", "tbl2", ignoreIfNotExists = false, purge = true)
+ checkEvents(DropTablePreEvent("db5", "tbl2") :: DropTableEvent("db5", "tbl2") :: Nil)
+
+ catalog.dropTable("db5", "tbl2", ignoreIfNotExists = true, purge = true)
+ checkEvents(DropTablePreEvent("db5", "tbl2") :: DropTableEvent("db5", "tbl2") :: Nil)
+ }
+
+ testWithCatalog("function") { (catalog, checkEvents) =>
+ // CREATE
+ val dbDefinition = createDbDefinition()
+
+ val functionDefinition = CatalogFunction(
+ identifier = FunctionIdentifier("fn7", Some("db5")),
+ className = "",
+ resources = Seq.empty)
+
+ val newIdentifier = functionDefinition.identifier.copy(funcName = "fn4")
+ val renamedFunctionDefinition = functionDefinition.copy(identifier = newIdentifier)
+
+ catalog.createDatabase(dbDefinition, ignoreIfExists = false)
+ checkEvents(CreateDatabasePreEvent("db5") :: CreateDatabaseEvent("db5") :: Nil)
+
+ catalog.createFunction("db5", functionDefinition)
+ checkEvents(CreateFunctionPreEvent("db5", "fn7") :: CreateFunctionEvent("db5", "fn7") :: Nil)
+
+ intercept[AnalysisException] {
+ catalog.createFunction("db5", functionDefinition)
+ }
+ checkEvents(CreateFunctionPreEvent("db5", "fn7") :: Nil)
+
+ // RENAME
+ catalog.renameFunction("db5", "fn7", "fn4")
+ checkEvents(
+ RenameFunctionPreEvent("db5", "fn7", "fn4") ::
+ RenameFunctionEvent("db5", "fn7", "fn4") :: Nil)
+ intercept[AnalysisException] {
+ catalog.renameFunction("db5", "fn7", "fn4")
+ }
+ checkEvents(RenameFunctionPreEvent("db5", "fn7", "fn4") :: Nil)
+
+ // DROP
+ intercept[AnalysisException] {
+ catalog.dropFunction("db5", "fn7")
+ }
+ checkEvents(DropFunctionPreEvent("db5", "fn7") :: Nil)
+
+ catalog.dropFunction("db5", "fn4")
+ checkEvents(DropFunctionPreEvent("db5", "fn4") :: DropFunctionEvent("db5", "fn4") :: Nil)
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala
index 42db4398e5072..014d0c0e72ad7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala
@@ -245,15 +245,12 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac
test("alter table schema") {
val catalog = newBasicCatalog()
- val tbl1 = catalog.getTable("db2", "tbl1")
- val newSchema = StructType(Seq(
+ val newDataSchema = StructType(Seq(
StructField("new_field_1", IntegerType),
- StructField("new_field_2", StringType),
- StructField("a", IntegerType),
- StructField("b", StringType)))
- catalog.alterTableSchema("db2", "tbl1", newSchema)
+ StructField("new_field_2", StringType)))
+ catalog.alterTableDataSchema("db2", "tbl1", newDataSchema)
val newTbl1 = catalog.getTable("db2", "tbl1")
- assert(newTbl1.schema == newSchema)
+ assert(newTbl1.dataSchema == newDataSchema)
}
test("get table") {
@@ -439,6 +436,18 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac
assert(catalog.listPartitions("db2", "tbl2", Some(Map("a" -> "unknown"))).isEmpty)
}
+ test("SPARK-21457: list partitions with special chars") {
+ val catalog = newBasicCatalog()
+ assert(catalog.listPartitions("db2", "tbl1").isEmpty)
+
+ val part1 = CatalogTablePartition(Map("a" -> "1", "b" -> "i+j"), storageFormat)
+ val part2 = CatalogTablePartition(Map("a" -> "1", "b" -> "i.j"), storageFormat)
+ catalog.createPartitions("db2", "tbl1", Seq(part1, part2), ignoreIfExists = false)
+
+ assert(catalog.listPartitions("db2", "tbl1", Some(part1.spec)).map(_.spec) == Seq(part1.spec))
+ assert(catalog.listPartitions("db2", "tbl1", Some(part2.spec)).map(_.spec) == Seq(part2.spec))
+ }
+
test("list partitions by filter") {
val tz = TimeZone.getDefault.getID
val catalog = newBasicCatalog()
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
index be8903000a0d1..a8c6c06e1f4d1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
@@ -452,9 +452,9 @@ abstract class SessionCatalogSuite extends PlanTest {
withBasicCatalog { sessionCatalog =>
sessionCatalog.createTable(newTable("t1", "default"), ignoreIfExists = false)
val oldTab = sessionCatalog.externalCatalog.getTable("default", "t1")
- sessionCatalog.alterTableSchema(
+ sessionCatalog.alterTableDataSchema(
TableIdentifier("t1", Some("default")),
- StructType(oldTab.dataSchema.add("c3", IntegerType) ++ oldTab.partitionSchema))
+ StructType(oldTab.dataSchema.add("c3", IntegerType)))
val newTab = sessionCatalog.externalCatalog.getTable("default", "t1")
// construct the expected table schema
@@ -464,13 +464,26 @@ abstract class SessionCatalogSuite extends PlanTest {
}
}
+ test("alter table add columns which are conflicting with partition columns") {
+ withBasicCatalog { sessionCatalog =>
+ sessionCatalog.createTable(newTable("t1", "default"), ignoreIfExists = false)
+ val oldTab = sessionCatalog.externalCatalog.getTable("default", "t1")
+ val e = intercept[AnalysisException] {
+ sessionCatalog.alterTableDataSchema(
+ TableIdentifier("t1", Some("default")),
+ StructType(oldTab.dataSchema.add("a", IntegerType)))
+ }.getMessage
+ assert(e.contains("Found duplicate column(s): a"))
+ }
+ }
+
test("alter table drop columns") {
withBasicCatalog { sessionCatalog =>
sessionCatalog.createTable(newTable("t1", "default"), ignoreIfExists = false)
val oldTab = sessionCatalog.externalCatalog.getTable("default", "t1")
val e = intercept[AnalysisException] {
- sessionCatalog.alterTableSchema(
- TableIdentifier("t1", Some("default")), StructType(oldTab.schema.drop(1)))
+ sessionCatalog.alterTableDataSchema(
+ TableIdentifier("t1", Some("default")), StructType(oldTab.dataSchema.drop(1)))
}.getMessage
assert(e.contains("We don't support dropping columns yet."))
}
@@ -498,17 +511,6 @@ abstract class SessionCatalogSuite extends PlanTest {
}
}
- test("get option of table metadata") {
- withBasicCatalog { catalog =>
- assert(catalog.getTableMetadataOption(TableIdentifier("tbl1", Some("db2")))
- == Option(catalog.externalCatalog.getTable("db2", "tbl1")))
- assert(catalog.getTableMetadataOption(TableIdentifier("unknown_table", Some("db2"))).isEmpty)
- intercept[NoSuchDatabaseException] {
- catalog.getTableMetadataOption(TableIdentifier("tbl1", Some("unknown_db")))
- }
- }
- }
-
test("lookup table relation") {
withBasicCatalog { catalog =>
val tempTable1 = Range(1, 10, 1, 10)
@@ -517,14 +519,14 @@ abstract class SessionCatalogSuite extends PlanTest {
catalog.setCurrentDatabase("db2")
// If we explicitly specify the database, we'll look up the relation in that database
assert(catalog.lookupRelation(TableIdentifier("tbl1", Some("db2"))).children.head
- .asInstanceOf[CatalogRelation].tableMeta == metastoreTable1)
+ .asInstanceOf[UnresolvedCatalogRelation].tableMeta == metastoreTable1)
// Otherwise, we'll first look up a temporary table with the same name
assert(catalog.lookupRelation(TableIdentifier("tbl1"))
== SubqueryAlias("tbl1", tempTable1))
// Then, if that does not exist, look up the relation in the current database
catalog.dropTable(TableIdentifier("tbl1"), ignoreIfNotExists = false, purge = false)
assert(catalog.lookupRelation(TableIdentifier("tbl1")).children.head
- .asInstanceOf[CatalogRelation].tableMeta == metastoreTable1)
+ .asInstanceOf[UnresolvedCatalogRelation].tableMeta == metastoreTable1)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 080f11b769388..e6d09bdae67d7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.ClosureCleaner
case class RepeatedStruct(s: Seq[PrimitiveData])
@@ -114,7 +115,9 @@ object ReferenceValueClass {
class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
OuterScopes.addOuterScope(this)
- implicit def encoder[T : TypeTag]: ExpressionEncoder[T] = ExpressionEncoder()
+ implicit def encoder[T : TypeTag]: ExpressionEncoder[T] = verifyNotLeakingReflectionObjects {
+ ExpressionEncoder()
+ }
// test flat encoders
encodeDecodeTest(false, "primitive boolean")
@@ -355,17 +358,27 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
checkNullable[String](true)
}
- test("null check for map key") {
+ test("null check for map key: String") {
val encoder = ExpressionEncoder[Map[String, Int]]()
val e = intercept[RuntimeException](encoder.toRow(Map(("a", 1), (null, 2))))
assert(e.getMessage.contains("Cannot use null as map key"))
}
+ test("null check for map key: Integer") {
+ val encoder = ExpressionEncoder[Map[Integer, String]]()
+ val e = intercept[RuntimeException](encoder.toRow(Map((1, "a"), (null, "b"))))
+ assert(e.getMessage.contains("Cannot use null as map key"))
+ }
+
private def encodeDecodeTest[T : ExpressionEncoder](
input: T,
testName: String): Unit = {
- test(s"encode/decode for $testName: $input") {
+ testAndVerifyNotLeakingReflectionObjects(s"encode/decode for $testName: $input") {
val encoder = implicitly[ExpressionEncoder[T]]
+
+ // Make sure encoder is serializable.
+ ClosureCleaner.clean((s: String) => encoder.getClass.getName)
+
val row = encoder.toRow(input)
val schema = encoder.schema.toAttributes
val boundEncoder = encoder.resolveAndBind()
@@ -435,4 +448,28 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
}
}
}
+
+ /**
+ * Verify the size of scala.reflect.runtime.JavaUniverse.undoLog before and after `func` to
+ * ensure we don't leak Scala reflection garbage.
+ *
+ * @see org.apache.spark.sql.catalyst.ScalaReflection.cleanUpReflectionObjects
+ */
+ private def verifyNotLeakingReflectionObjects[T](func: => T): T = {
+ def undoLogSize: Int = {
+ scala.reflect.runtime.universe
+ .asInstanceOf[scala.reflect.runtime.JavaUniverse].undoLog.log.size
+ }
+
+ val previousUndoLogSize = undoLogSize
+ val r = func
+ assert(previousUndoLogSize == undoLogSize)
+ r
+ }
+
+ private def testAndVerifyNotLeakingReflectionObjects(testName: String)(testFun: => Any): Unit = {
+ test(testName) {
+ verifyNotLeakingReflectionObjects(testFun)
+ }
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index 0d86efda7ea86..0676b948d4298 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -327,4 +327,13 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkConsistencyBetweenInterpretedAndCodegen(Greatest, dt, 2)
}
}
+
+ test("SPARK-22499: Least and greatest should not generate codes beyond 64KB") {
+ val N = 2000
+ val strings = (1 to N).map(x => "s" * x)
+ val inputsExpr = strings.map(Literal.create(_, StringType))
+
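+ // Lexicographically, "s" is the smallest of these inputs and "s" * N the largest.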
+ checkEvaluation(Least(inputsExpr), "s" * 1, EmptyRow)
+ checkEvaluation(Greatest(inputsExpr), "s" * N, EmptyRow)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index a7ffa884d2286..7837d6529d12b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -827,4 +827,22 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(cast(Literal.create(input, from), to), input)
}
+
+ test("SPARK-22500: cast for struct should not generate codes beyond 64KB") {
+ val N = 25
+
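+ // Casting a struct of 25 inner structs, each with 25 fields, yields 625 field-level
+ // casts, stressing the 64KB generated-method limit.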
+ val fromInner = new StructType(
+ (1 to N).map(i => StructField(s"s$i", DoubleType)).toArray)
+ val toInner = new StructType(
+ (1 to N).map(i => StructField(s"i$i", IntegerType)).toArray)
+ val inputInner = Row.fromSeq((1 to N).map(i => i + 0.5))
+ val outputInner = Row.fromSeq((1 to N))
+ val fromOuter = new StructType(
+ (1 to N).map(i => StructField(s"s$i", fromInner)).toArray)
+ val toOuter = new StructType(
+ (1 to N).map(i => StructField(s"s$i", toInner)).toArray)
+ val inputOuter = Row.fromSeq((1 to N).map(_ => inputInner))
+ val outputOuter = Row.fromSeq((1 to N).map(_ => outputInner))
+ checkEvaluation(cast(Literal.create(inputOuter, fromOuter), toOuter), outputOuter)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index 7ea0bec145481..368f8e1b723f7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -324,4 +324,43 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
// should not throw exception
projection(row)
}
+
+ test("SPARK-21720: split large predications into blocks due to JVM code size limit") {
+ val length = 600
+
+ val input = new GenericInternalRow(length)
+ val utf8Str = UTF8String.fromString("abc")
+ for (i <- 0 until length) {
+ input.update(i, utf8Str)
+ }
+
+ var exprOr: Expression = Literal(false)
+ for (i <- 0 until length) {
+ exprOr = Or(EqualTo(BoundReference(i, StringType, true), Literal(s"c$i")), exprOr)
+ }
+
+ val planOr = GenerateMutableProjection.generate(Seq(exprOr))
+ val actualOr = planOr(input).toSeq(Seq(exprOr.dataType))
+ assert(actualOr.length == 1)
+ val expectedOr = false
+
+ if (!checkResult(actualOr.head, expectedOr, exprOr.dataType)) {
+ fail(s"Incorrect Evaluation: expressions: $exprOr, actual: $actualOr, expected: $expectedOr")
+ }
+
+ var exprAnd: Expression = Literal(true)
+ for (i <- 0 until length) {
+ exprAnd = And(EqualTo(BoundReference(i, StringType, true), Literal(s"c$i")), exprAnd)
+ }
+
+ val planAnd = GenerateMutableProjection.generate(Seq(exprAnd))
+ val actualAnd = planAnd(input).toSeq(Seq(exprAnd.dataType))
+ assert(actualAnd.length == 1)
+ val expectedAnd = false
+
+ if (!checkResult(actualAnd.head, expectedAnd, exprAnd.dataType)) {
+ fail(
+ s"Incorrect Evaluation: expressions: $exprAnd, actual: $actualAnd, expected: $expectedAnd")
+ }
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
index 9978f35a03810..257c2a3bef974 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
@@ -76,6 +76,9 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}
checkEvaluation(DayOfYear(Literal.create(null, DateType)), null)
+
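+ // Dates on either side of the Julian-to-Gregorian calendar cutover in October 1582.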
+ checkEvaluation(DayOfYear(Literal(new Date(sdf.parse("1582-10-15 13:10:15").getTime))), 288)
+ checkEvaluation(DayOfYear(Literal(new Date(sdf.parse("1582-10-04 13:10:15").getTime))), 277)
checkConsistencyBetweenInterpretedAndCodegen(DayOfYear, DateType)
}
@@ -96,6 +99,8 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}
}
+ checkEvaluation(Year(Literal(new Date(sdf.parse("1582-01-01 13:10:15").getTime))), 1582)
+ checkEvaluation(Year(Literal(new Date(sdf.parse("1581-12-31 13:10:15").getTime))), 1581)
checkConsistencyBetweenInterpretedAndCodegen(Year, DateType)
}
@@ -116,6 +121,9 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}
}
+
+ checkEvaluation(Quarter(Literal(new Date(sdf.parse("1582-10-01 13:10:15").getTime))), 4)
+ checkEvaluation(Quarter(Literal(new Date(sdf.parse("1582-09-30 13:10:15").getTime))), 3)
checkConsistencyBetweenInterpretedAndCodegen(Quarter, DateType)
}
@@ -125,6 +133,10 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Month(Cast(Literal(sdfDate.format(d)), DateType, gmtId)), 4)
checkEvaluation(Month(Cast(Literal(ts), DateType, gmtId)), 11)
+ checkEvaluation(Month(Literal(new Date(sdf.parse("1582-04-28 13:10:15").getTime))), 4)
+ checkEvaluation(Month(Literal(new Date(sdf.parse("1582-10-04 13:10:15").getTime))), 10)
+ checkEvaluation(Month(Literal(new Date(sdf.parse("1582-10-15 13:10:15").getTime))), 10)
+
val c = Calendar.getInstance()
(2003 to 2004).foreach { y =>
(0 to 3).foreach { m =>
@@ -146,6 +158,10 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(DayOfMonth(Cast(Literal(sdfDate.format(d)), DateType, gmtId)), 8)
checkEvaluation(DayOfMonth(Cast(Literal(ts), DateType, gmtId)), 8)
+ checkEvaluation(DayOfMonth(Literal(new Date(sdf.parse("1582-04-28 13:10:15").getTime))), 28)
+ checkEvaluation(DayOfMonth(Literal(new Date(sdf.parse("1582-10-15 13:10:15").getTime))), 15)
+ checkEvaluation(DayOfMonth(Literal(new Date(sdf.parse("1582-10-04 13:10:15").getTime))), 4)
+
val c = Calendar.getInstance()
(1999 to 2000).foreach { y =>
c.set(y, 0, 1, 0, 0, 0)
@@ -160,7 +176,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
test("Seconds") {
assert(Second(Literal.create(null, DateType), gmtId).resolved === false)
- assert(Second(Cast(Literal(d), TimestampType), None).resolved === true)
+ assert(Second(Cast(Literal(d), TimestampType, gmtId), gmtId).resolved === true)
checkEvaluation(Second(Cast(Literal(d), TimestampType, gmtId), gmtId), 0)
checkEvaluation(Second(Cast(Literal(sdf.format(d)), TimestampType, gmtId), gmtId), 15)
checkEvaluation(Second(Literal(ts), gmtId), 15)
@@ -186,6 +202,8 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(WeekOfYear(Cast(Literal(sdfDate.format(d)), DateType, gmtId)), 15)
checkEvaluation(WeekOfYear(Cast(Literal(ts), DateType, gmtId)), 45)
checkEvaluation(WeekOfYear(Cast(Literal("2011-05-06"), DateType, gmtId)), 18)
+ checkEvaluation(WeekOfYear(Literal(new Date(sdf.parse("1582-10-15 13:10:15").getTime))), 40)
+ checkEvaluation(WeekOfYear(Literal(new Date(sdf.parse("1582-10-04 13:10:15").getTime))), 40)
checkConsistencyBetweenInterpretedAndCodegen(WeekOfYear, DateType)
}
@@ -220,7 +238,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
test("Hour") {
assert(Hour(Literal.create(null, DateType), gmtId).resolved === false)
- assert(Hour(Literal(ts), None).resolved === true)
+ assert(Hour(Literal(ts), gmtId).resolved === true)
checkEvaluation(Hour(Cast(Literal(d), TimestampType, gmtId), gmtId), 0)
checkEvaluation(Hour(Cast(Literal(sdf.format(d)), TimestampType, gmtId), gmtId), 13)
checkEvaluation(Hour(Literal(ts), gmtId), 13)
@@ -246,7 +264,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
test("Minute") {
assert(Minute(Literal.create(null, DateType), gmtId).resolved === false)
- assert(Minute(Literal(ts), None).resolved === true)
+ assert(Minute(Literal(ts), gmtId).resolved === true)
checkEvaluation(Minute(Cast(Literal(d), TimestampType, gmtId), gmtId), 0)
checkEvaluation(
Minute(Cast(Literal(sdf.format(d)), TimestampType, gmtId), gmtId), 10)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index 1ba6dd1c5e8ca..b6399edb68dd6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -25,10 +25,12 @@ import org.scalatest.prop.GeneratorDrivenPropertyChecks
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer
import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
+import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -45,7 +47,8 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
protected def checkEvaluation(
expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = {
val serializer = new JavaSerializer(new SparkConf()).newInstance
- val expr: Expression = serializer.deserialize(serializer.serialize(expression))
+ val resolver = ResolveTimeZone(new SQLConf)
+ val expr = resolver.resolveTimeZones(serializer.deserialize(serializer.serialize(expression)))
val catalystValue = CatalystTypeConverters.convertToCatalyst(expected)
checkEvaluationWithoutCodegen(expr, catalystValue, inputRow)
checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
index 59fc8eaf73d61..112a4a09728ae 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
@@ -639,6 +639,35 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
assert(hiveHashPlan(wideRow).getInt(0) == hiveHashEval)
}
+ test("SPARK-22284: Compute hash for nested structs") {
+ val M = 80
+ val N = 10
+ val L = M * N
+ val O = 50
+ val seed = 42
+
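+ // Builds a row of 50 structs, each holding 80 structs of 10 single-string structs
+ // (40,000 leaf strings in total), to exercise hashing of deeply nested schemas.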
+ val wideRow = new GenericInternalRow(Seq.tabulate(O)(k =>
+ new GenericInternalRow(Seq.tabulate(M)(j =>
+ new GenericInternalRow(Seq.tabulate(N)(i =>
+ new GenericInternalRow(Array[Any](
+ UTF8String.fromString((k * L + j * N + i).toString))))
+ .toArray[Any])).toArray[Any])).toArray[Any])
+ val inner = new StructType(
+ (0 until N).map(_ => StructField("structOfString", structOfString)).toArray)
+ val outer = new StructType(
+ (0 until M).map(_ => StructField("structOfStructOfString", inner)).toArray)
+ val schema = new StructType(
+ (0 until O).map(_ => StructField("structOfStructOfStructOfString", outer)).toArray)
+ val exprs = schema.fields.zipWithIndex.map { case (f, i) =>
+ BoundReference(i, f.dataType, true)
+ }
+ val murmur3HashExpr = Murmur3Hash(exprs, seed)
+ val murmur3HashPlan = GenerateMutableProjection.generate(Seq(murmur3HashExpr))
+
+ val murmur3HashEval = Murmur3Hash(exprs, seed).eval(wideRow)
+ assert(murmur3HashPlan(wideRow).getInt(0) == murmur3HashEval)
+ }
+
private def testHash(inputSchema: StructType): Unit = {
val inputGenerator = RandomDataGenerator.forType(inputSchema, nullable = false).get
val encoder = RowEncoder(inputSchema)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
index c5b72235e5db0..53b54de606930 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
@@ -21,6 +21,7 @@ import java.util.Calendar
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils, GenericArrayData, PermissiveMode}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -39,6 +40,10 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
|"fb:testid":"1234"}
|""".stripMargin
+ /* Invalid JSON with leading null bytes triggers java.io.CharConversionException
+ in Jackson's JsonFactory.createParser(byte[]) due to RFC 4627 encoding detection. */
+ val badJson = "\u0000\u0000\u0000A\u0001AAA"
+
test("$.store.bicycle") {
checkEvaluation(
GetJsonObject(Literal(json), Literal("$.store.bicycle")),
@@ -224,6 +229,13 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
null)
}
+ test("SPARK-16548: character conversion") {
+ checkEvaluation(
+ GetJsonObject(Literal(badJson), Literal("$.a")),
+ null
+ )
+ }
+
test("non foldable literal") {
checkEvaluation(
GetJsonObject(NonFoldableLiteral(json), NonFoldableLiteral("$.fb:testid")),
@@ -340,6 +352,12 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
InternalRow(null, null, null, null, null))
}
+ test("SPARK-16548: json_tuple - invalid json with leading nulls") {
+ checkJsonTuple(
+ JsonTuple(Literal(badJson) :: jsonTupleQuery),
+ InternalRow(null, null, null, null, null))
+ }
+
test("json_tuple - preserve newlines") {
checkJsonTuple(
JsonTuple(Literal("{\"a\":\"b\nc\"}") :: Literal("a") :: Nil),
@@ -436,6 +454,13 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
)
}
+ test("SPARK-20549: from_json bad UTF-8") {
+ val schema = StructType(StructField("a", IntegerType) :: Nil)
+ checkEvaluation(
+ JsonToStructs(schema, Map.empty, Literal(badJson), gmtId),
+ null)
+ }
+
test("from_json with timestamp") {
val schema = StructType(StructField("t", TimestampType) :: Nil)
@@ -566,4 +591,26 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
"""{"t":"2015-12-31T16:00:00"}"""
)
}
+
+ test("to_json: verify MapType's value type instead of key type") {
+ // Map keys are always written as JSON strings, so the key type does not matter here.
+ val mapType1 = MapType(CalendarIntervalType, IntegerType)
+ val schema1 = StructType(StructField("a", mapType1) :: Nil)
+ val struct1 = Literal.create(null, schema1)
+ checkEvaluation(
+ StructsToJson(Map.empty, struct1, gmtId),
+ null
+ )
+
+ // The value type must be valid for converting to JSON.
+ val mapType2 = MapType(IntegerType, CalendarIntervalType)
+ val schema2 = StructType(StructField("a", mapType2) :: Nil)
+ val struct2 = Literal.create(null, schema2)
+ intercept[TreeNodeException[_]] {
+ checkEvaluation(
+ StructsToJson(Map.empty, struct2, gmtId),
+ null
+ )
+ }
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
index 6b5bfac94645c..69ada8216515d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
@@ -252,6 +252,16 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 3))
checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 0))
checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(5, 0))
+
+ val doublePi: Double = 3.1415
+ val floatPi: Float = 3.1415f
+ val longLit: Long = 12345678901234567L
+ checkEvaluation(Ceil(doublePi), 4L, EmptyRow)
+ checkEvaluation(Ceil(floatPi.toDouble), 4L, EmptyRow)
+ checkEvaluation(Ceil(longLit), longLit, EmptyRow)
+ checkEvaluation(Ceil(-doublePi), -3L, EmptyRow)
+ checkEvaluation(Ceil(-floatPi.toDouble), -3L, EmptyRow)
+ checkEvaluation(Ceil(-longLit), -longLit, EmptyRow)
}
test("floor") {
@@ -262,6 +272,16 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 3))
checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 0))
checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(5, 0))
+
+ val doublePi: Double = 3.1415
+ val floatPi: Float = 3.1415f
+ val longLit: Long = 12345678901234567L
+ checkEvaluation(Floor(doublePi), 3L, EmptyRow)
+ checkEvaluation(Floor(floatPi.toDouble), 3L, EmptyRow)
+ checkEvaluation(Floor(longLit), longLit, EmptyRow)
+ checkEvaluation(Floor(-doublePi), -4L, EmptyRow)
+ checkEvaluation(Floor(-floatPi.toDouble), -4L, EmptyRow)
+ checkEvaluation(Floor(-longLit), -longLit, EmptyRow)
}
test("factorial") {
@@ -546,15 +566,14 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14),
BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159),
BigDecimal(3.141593), BigDecimal(3.1415927))
- // round_scale > current_scale would result in precision increase
- // and not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null
+
(0 to 7).foreach { i =>
checkEvaluation(Round(bdPi, i), bdResults(i), EmptyRow)
checkEvaluation(BRound(bdPi, i), bdResults(i), EmptyRow)
}
(8 to 10).foreach { scale =>
- checkEvaluation(Round(bdPi, scale), null, EmptyRow)
- checkEvaluation(BRound(bdPi, scale), null, EmptyRow)
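+ // Rounding to a scale larger than the value's own scale now leaves the BigDecimal unchanged.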
+ checkEvaluation(Round(bdPi, scale), bdPi, EmptyRow)
+ checkEvaluation(BRound(bdPi, scale), bdPi, EmptyRow)
}
DataTypeTestUtils.numericTypes.foreach { dataType =>
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala
index 5064a1f63f83d..2ea6aa8ae78ce 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala
@@ -133,4 +133,14 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(AtLeastNNonNulls(3, nullOnly), true, EmptyRow)
checkEvaluation(AtLeastNNonNulls(4, nullOnly), false, EmptyRow)
}
+
+ test("Coalesce should not throw 64kb exception") {
+ val inputs = (1 to 2500).map(x => Literal(s"x_$x"))
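+ // Coalesce returns its first non-null argument, which here is the literal "x_1".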
+ checkEvaluation(Coalesce(inputs), "x_1")
+ }
+
+ test("AtLeastNNonNulls should not throw 64kb exception") {
+ val inputs = (1 to 4000).map(x => Literal(s"x_$x"))
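+ // All 4000 inputs are non-null, so requiring at least one non-null evaluates to true.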
+ checkEvaluation(AtLeastNNonNulls(1, inputs), true)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala
index 190fab5d249bb..d0604b8eb7675 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.sql.{RandomDataGenerator, Row}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateOrdering, LazilyGeneratedOrdering}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, GenerateOrdering, LazilyGeneratedOrdering}
import org.apache.spark.sql.types._
class OrderingSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -137,4 +137,32 @@ class OrderingSuite extends SparkFunSuite with ExpressionEvalHelper {
// verify that we can support up to 5000 ordering comparisons, which should be sufficient
GenerateOrdering.generate(Array.fill(5000)(sortOrder))
}
+
+ test("SPARK-21344: BinaryType comparison does signed byte array comparison") {
+ val data = Seq(
+ (Array[Byte](1), Array[Byte](-1)),
+ (Array[Byte](1, 1, 1, 1, 1), Array[Byte](1, 1, 1, 1, -1)),
+ (Array[Byte](1, 1, 1, 1, 1, 1, 1, 1, 1), Array[Byte](1, 1, 1, 1, 1, 1, 1, 1, -1))
+ )
+ data.foreach { case (b1, b2) =>
+ val rowOrdering = InterpretedOrdering.forSchema(Seq(BinaryType))
+ val genOrdering = GenerateOrdering.generate(
+ BoundReference(0, BinaryType, nullable = true).asc :: Nil)
+ val rowType = StructType(StructField("b", BinaryType, nullable = true) :: Nil)
+ val toCatalyst = CatalystTypeConverters.createToCatalystConverter(rowType)
+ val rowB1 = toCatalyst(Row(b1)).asInstanceOf[InternalRow]
+ val rowB2 = toCatalyst(Row(b2)).asInstanceOf[InternalRow]
+ assert(rowOrdering.compare(rowB1, rowB2) < 0)
+ assert(genOrdering.compare(rowB1, rowB2) < 0)
+ }
+ }
+
+ test("SPARK-22591: GenerateOrdering shouldn't change ctx.INPUT_ROW") {
+ val ctx = new CodegenContext()
+ ctx.INPUT_ROW = null
+
+ val schema = new StructType().add("field", FloatType, nullable = true)
+ GenerateOrdering.genComparisons(ctx, schema)
+ assert(ctx.INPUT_ROW == null)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
index 6fe295c3dd936..15ae62477d44f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
@@ -35,7 +35,8 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
test(s"3VL $name") {
truthTable.foreach {
case (l, r, answer) =>
- val expr = op(NonFoldableLiteral(l, BooleanType), NonFoldableLiteral(r, BooleanType))
+ val expr = op(NonFoldableLiteral.create(l, BooleanType),
+ NonFoldableLiteral.create(r, BooleanType))
checkEvaluation(expr, answer)
}
}
@@ -72,7 +73,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
(false, true) ::
(null, null) :: Nil
notTrueTable.foreach { case (v, answer) =>
- checkEvaluation(Not(NonFoldableLiteral(v, BooleanType)), answer)
+ checkEvaluation(Not(NonFoldableLiteral.create(v, BooleanType)), answer)
}
checkConsistencyBetweenInterpretedAndCodegen(Not, BooleanType)
}
@@ -120,22 +121,26 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
(null, null, null) :: Nil)
test("IN") {
- checkEvaluation(In(NonFoldableLiteral(null, IntegerType), Seq(Literal(1), Literal(2))), null)
- checkEvaluation(In(NonFoldableLiteral(null, IntegerType),
- Seq(NonFoldableLiteral(null, IntegerType))), null)
- checkEvaluation(In(NonFoldableLiteral(null, IntegerType), Seq.empty), null)
+ checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), Seq(Literal(1),
+ Literal(2))), null)
+ checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType),
+ Seq(NonFoldableLiteral.create(null, IntegerType))), null)
+ checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), Seq.empty), null)
checkEvaluation(In(Literal(1), Seq.empty), false)
- checkEvaluation(In(Literal(1), Seq(NonFoldableLiteral(null, IntegerType))), null)
- checkEvaluation(In(Literal(1), Seq(Literal(1), NonFoldableLiteral(null, IntegerType))), true)
- checkEvaluation(In(Literal(2), Seq(Literal(1), NonFoldableLiteral(null, IntegerType))), null)
+ checkEvaluation(In(Literal(1), Seq(NonFoldableLiteral.create(null, IntegerType))), null)
+ checkEvaluation(In(Literal(1), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))),
+ true)
+ checkEvaluation(In(Literal(2), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))),
+ null)
checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true)
checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true)
checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false)
checkEvaluation(
- And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1), Literal(2)))),
+ And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1),
+ Literal(2)))),
true)
- val ns = NonFoldableLiteral(null, StringType)
+ val ns = NonFoldableLiteral.create(null, StringType)
checkEvaluation(In(ns, Seq(Literal("1"), Literal("2"))), null)
checkEvaluation(In(ns, Seq(ns)), null)
checkEvaluation(In(Literal("a"), Seq(ns)), null)
@@ -155,7 +160,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
case _ => value
}
}
- val input = inputData.map(NonFoldableLiteral(_, t))
+ val input = inputData.map(NonFoldableLiteral.create(_, t))
val expected = if (inputData(0) == null) {
null
} else if (inputData.slice(1, 10).contains(inputData(0))) {
@@ -169,6 +174,12 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}
+ test("SPARK-22501: In should not generate codes beyond 64KB") {
+ val N = 3000
+ val sets = (1 to N).map(i => Literal(i.toDouble))
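+ // 1.0 is among the 3000 candidate literals, so the membership test evaluates to true.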
+ checkEvaluation(In(Literal(1.0D), sets), true)
+ }
+
test("INSET") {
val hS = HashSet[Any]() + 1 + 2
val nS = HashSet[Any]() + 1 + 2 + null
@@ -279,7 +290,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
test("BinaryComparison: null test") {
// Use -1 (default value for codegen) which can trigger some weird bugs, e.g. SPARK-14757
val normalInt = Literal(-1)
- val nullInt = NonFoldableLiteral(null, IntegerType)
+ val nullInt = NonFoldableLiteral.create(null, IntegerType)
def nullTest(op: (Expression, Expression) => Expression): Unit = {
checkEvaluation(op(normalInt, nullInt), null)
@@ -329,4 +340,11 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
val infinity = Literal(Double.PositiveInfinity)
checkEvaluation(EqualTo(infinity, infinity), true)
}
+
+ test("SPARK-24007: EqualNullSafe for FloatType and DoubleType might generate a wrong result") {
+ checkEvaluation(EqualNullSafe(Literal(null, FloatType), Literal(-1.0f)), false)
+ checkEvaluation(EqualNullSafe(Literal(-1.0f), Literal(null, FloatType)), false)
+ checkEvaluation(EqualNullSafe(Literal(null, DoubleType), Literal(-1.0d)), false)
+ checkEvaluation(EqualNullSafe(Literal(-1.0d), Literal(null, DoubleType)), false)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 26978a0482fc7..085d912e64150 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -46,6 +46,12 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
// scalastyle:on
}
+ test("SPARK-22498: Concat should not generate codes beyond 64KB") {
+ val N = 5000
+ val strs = (1 to N).map(x => s"s$x")
+ checkEvaluation(Concat(strs.map(Literal.create(_, StringType))), strs.mkString, EmptyRow)
+ }
+
test("concat_ws") {
def testConcatWs(expected: String, sep: String, inputs: Any*): Unit = {
val inputExprs = inputs.map {
@@ -75,6 +81,19 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
// scalastyle:on
}
+ test("SPARK-22549: ConcatWs should not generate codes beyond 64KB") {
+ val N = 5000
+ val sepExpr = Literal.create("#", StringType)
+ val strings1 = (1 to N).map(x => s"s$x")
+ val inputsExpr1 = strings1.map(Literal.create(_, StringType))
+ checkEvaluation(ConcatWs(sepExpr +: inputsExpr1), strings1.mkString("#"), EmptyRow)
+
+ val strings2 = (1 to N).map(x => Seq(s"s$x"))
+ val inputsExpr2 = strings2.map(Literal.create(_, ArrayType(StringType)))
+ checkEvaluation(
+ ConcatWs(sepExpr +: inputsExpr2), strings2.map(s => s(0)).mkString("#"), EmptyRow)
+ }
+
test("elt") {
def testElt(result: String, n: java.lang.Integer, args: String*): Unit = {
checkEvaluation(
@@ -98,6 +117,13 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
assert(Elt(Seq(Literal(1), Literal(2))).checkInputDataTypes().isFailure)
}
+ test("SPARK-22550: Elt should not generate codes beyond 64KB") {
+ val N = 10000
+ val strings = (1 to N).map(x => s"s$x")
+ val args = Literal.create(N, IntegerType) +: strings.map(Literal.create(_, StringType))
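+ // The leading argument N selects the N-th (here the last) of the string arguments.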
+ checkEvaluation(Elt(args), s"s$N")
+ }
+
test("StringComparison") {
val row = create_row("abc", null)
val c1 = 'a.string.at(0)
@@ -408,6 +434,14 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
FormatString(Literal("aa%d%s"), 12, Literal.create(null, StringType)), "aa12null")
}
+ test("SPARK-22603: FormatString should not generate codes beyond 64KB") {
+ val N = 4500
+ val args = (1 to N).map(i => Literal.create(i.toString, StringType))
+ val format = "%s" * N
+ val expected = (1 to N).map(i => i.toString).mkString
+ checkEvaluation(FormatString(Literal(format) +: args: _*), expected)
+ }
+
test("INSTR") {
val s1 = 'a.string.at(0)
val s2 = 'b.string.at(1)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala
new file mode 100644
index 0000000000000..e9d21f8a8ebcd
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.BoundReference
+import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
+import org.apache.spark.sql.types.{DataType, Decimal, StringType, StructType}
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+
+class GenerateUnsafeProjectionSuite extends SparkFunSuite {
+ test("Test unsafe projection string access pattern") {
+ val dataType = (new StructType).add("a", StringType)
+ val exprs = BoundReference(0, dataType, nullable = true) :: Nil
+ val projection = GenerateUnsafeProjection.generate(exprs)
+ val result = projection.apply(InternalRow(AlwaysNull))
+ assert(!result.isNullAt(0))
+ assert(result.getStruct(0, 1).isNullAt(0))
+ }
+}
+
+object AlwaysNull extends InternalRow {
+ override def numFields: Int = 1
+ override def setNullAt(i: Int): Unit = {}
+ override def copy(): InternalRow = this
+ override def anyNull: Boolean = true
+ override def isNullAt(ordinal: Int): Boolean = true
+ override def update(i: Int, value: Any): Unit = notSupported
+ override def getBoolean(ordinal: Int): Boolean = notSupported
+ override def getByte(ordinal: Int): Byte = notSupported
+ override def getShort(ordinal: Int): Short = notSupported
+ override def getInt(ordinal: Int): Int = notSupported
+ override def getLong(ordinal: Int): Long = notSupported
+ override def getFloat(ordinal: Int): Float = notSupported
+ override def getDouble(ordinal: Int): Double = notSupported
+ override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = notSupported
+ override def getUTF8String(ordinal: Int): UTF8String = notSupported
+ override def getBinary(ordinal: Int): Array[Byte] = notSupported
+ override def getInterval(ordinal: Int): CalendarInterval = notSupported
+ override def getStruct(ordinal: Int, numFields: Int): InternalRow = notSupported
+ override def getArray(ordinal: Int): ArrayData = notSupported
+ override def getMap(ordinal: Int): MapData = notSupported
+ override def get(ordinal: Int, dataType: DataType): AnyRef = notSupported
+ private def notSupported: Nothing = throw new UnsupportedOperationException
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala
index 9f19745cefd20..75c6beeb32150 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala
@@ -22,8 +22,10 @@ import scala.util.Random
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.RandomDataGenerator
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
-import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.{JoinedRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
/**
* Test suite for [[GenerateUnsafeRowJoiner]].
@@ -45,6 +47,32 @@ class GenerateUnsafeRowJoinerSuite extends SparkFunSuite {
testConcat(64, 64, fixed)
}
+ test("rows with all empty strings") {
+ val schema = StructType(Seq(
+ StructField("f1", StringType), StructField("f2", StringType)))
+ val row: UnsafeRow = UnsafeProjection.create(schema).apply(
+ InternalRow(UTF8String.EMPTY_UTF8, UTF8String.EMPTY_UTF8))
+ testConcat(schema, row, schema, row)
+ }
+
+ test("rows with all empty int arrays") {
+ val schema = StructType(Seq(
+ StructField("f1", ArrayType(IntegerType)), StructField("f2", ArrayType(IntegerType))))
+ val emptyIntArray =
+ ExpressionEncoder[Array[Int]]().resolveAndBind().toRow(Array.emptyIntArray).getArray(0)
+ val row: UnsafeRow = UnsafeProjection.create(schema).apply(
+ InternalRow(emptyIntArray, emptyIntArray))
+ testConcat(schema, row, schema, row)
+ }
+
+ test("alternating empty and non-empty strings") {
+ val schema = StructType(Seq(
+ StructField("f1", StringType), StructField("f2", StringType)))
+ val row: UnsafeRow = UnsafeProjection.create(schema).apply(
+ InternalRow(UTF8String.EMPTY_UTF8, UTF8String.fromString("foo")))
+ testConcat(schema, row, schema, row)
+ }
+
test("randomized fix width types") {
for (i <- 0 until 20) {
testConcatOnce(Random.nextInt(100), Random.nextInt(100), fixed)
@@ -66,6 +94,11 @@ class GenerateUnsafeRowJoinerSuite extends SparkFunSuite {
}
}
+ test("SPARK-22508: GenerateUnsafeRowJoiner.create should not generate codes beyond 64KB") {
+ val N = 3000
+ testConcatOnce(N, N, variable)
+ }
+
private def testConcat(numFields1: Int, numFields2: Int, candidateTypes: Seq[DataType]): Unit = {
for (i <- 0 until 10) {
testConcatOnce(numFields1, numFields2, candidateTypes)
@@ -89,27 +122,84 @@ class GenerateUnsafeRowJoinerSuite extends SparkFunSuite {
val extRow2 = RandomDataGenerator.forType(schema2, nullable = false).get.apply()
val row1 = converter1.apply(internalConverter1.apply(extRow1).asInstanceOf[InternalRow])
val row2 = converter2.apply(internalConverter2.apply(extRow2).asInstanceOf[InternalRow])
+ testConcat(schema1, row1, schema2, row2)
+ }
+
+ private def testConcat(
+ schema1: StructType,
+ row1: UnsafeRow,
+ schema2: StructType,
+ row2: UnsafeRow): Unit = {
// Run the joiner.
val mergedSchema = StructType(schema1 ++ schema2)
val concater = GenerateUnsafeRowJoiner.create(schema1, schema2)
- val output = concater.join(row1, row2)
+ val output: UnsafeRow = concater.join(row1, row2)
+
+ // We'll also compare to an UnsafeRow produced with JoinedRow + UnsafeProjection. This ensures
+ // that unused space in the row (e.g. leftover bits in the null-tracking bitmap) is written
+ // correctly.
+ val expectedOutput: UnsafeRow = {
+ val joinedRowProjection = UnsafeProjection.create(mergedSchema)
+ val joined = new JoinedRow()
+ joinedRowProjection.apply(joined.apply(row1, row2))
+ }
// Test everything equals ...
for (i <- mergedSchema.indices) {
+ val dataType = mergedSchema(i).dataType
if (i < schema1.size) {
assert(output.isNullAt(i) === row1.isNullAt(i))
if (!output.isNullAt(i)) {
- assert(output.get(i, mergedSchema(i).dataType) === row1.get(i, mergedSchema(i).dataType))
+ assert(output.get(i, dataType) === row1.get(i, dataType))
+ assert(output.get(i, dataType) === expectedOutput.get(i, dataType))
}
} else {
assert(output.isNullAt(i) === row2.isNullAt(i - schema1.size))
if (!output.isNullAt(i)) {
- assert(output.get(i, mergedSchema(i).dataType) ===
- row2.get(i - schema1.size, mergedSchema(i).dataType))
+ assert(output.get(i, dataType) === row2.get(i - schema1.size, dataType))
+ assert(output.get(i, dataType) === expectedOutput.get(i, dataType))
}
}
}
+
+ assert(
+ expectedOutput.getSizeInBytes == output.getSizeInBytes,
+ "output isn't same size in bytes as slow path")
+
+ // Compare the UnsafeRows byte-by-byte so that we can print more useful debug information in
+ // case this assertion fails:
+ val actualBytes = output.getBaseObject.asInstanceOf[Array[Byte]]
+ .take(output.getSizeInBytes)
+ val expectedBytes = expectedOutput.getBaseObject.asInstanceOf[Array[Byte]]
+ .take(expectedOutput.getSizeInBytes)
+
+ val bitsetWidth = UnsafeRow.calculateBitSetWidthInBytes(expectedOutput.numFields())
+ val actualBitset = actualBytes.take(bitsetWidth)
+ val expectedBitset = expectedBytes.take(bitsetWidth)
+ assert(actualBitset === expectedBitset, "bitsets were not equal")
+
+ val fixedLengthSize = expectedOutput.numFields() * 8
+ val actualFixedLength = actualBytes.slice(bitsetWidth, bitsetWidth + fixedLengthSize)
+ val expectedFixedLength = expectedBytes.slice(bitsetWidth, bitsetWidth + fixedLengthSize)
+ if (actualFixedLength !== expectedFixedLength) {
+ actualFixedLength.grouped(8)
+ .zip(expectedFixedLength.grouped(8))
+ .zip(mergedSchema.fields.toIterator)
+ .foreach {
+ case ((actual, expected), field) =>
+ assert(actual === expected, s"Fixed length sections are not equal for field $field")
+ }
+ fail("Fixed length sections were not equal")
+ }
+
+ val variableLengthStart = bitsetWidth + fixedLengthSize
+ val actualVariableLength = actualBytes.drop(variableLengthStart)
+ val expectedVariableLength = expectedBytes.drop(variableLengthStart)
+ assert(actualVariableLength === expectedVariableLength, "variable length sections were not equal")
+
+ assert(output.hashCode() == expectedOutput.hashCode(), "hash codes were not equal")
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
index 935bff7cef2e8..c275f997ba6e9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.Row
class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
@@ -42,6 +43,16 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.string)
+ val testRelationWithData = LocalRelation.fromExternalRows(
+ testRelation.output, Seq(Row(1, 2, 3, "abc"))
+ )
+
+ private def checkCondition(input: Expression, expected: LogicalPlan): Unit = {
+ val plan = testRelationWithData.where(input).analyze
+ val actual = Optimize.execute(plan)
+ comparePlans(actual, expected)
+ }
+
private def checkCondition(input: Expression, expected: Expression): Unit = {
val plan = testRelation.where(input).analyze
val actual = Optimize.execute(plan)
@@ -160,4 +171,12 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
testRelation.where('a > 2 || ('b > 3 && 'b < 5)))
comparePlans(actual, expected)
}
+
+ test("Complementation Laws") {
+ checkCondition('a && !'a, testRelation)
+ checkCondition(!'a && 'a, testRelation)
+
+ checkCondition('a || !'a, testRelationWithData)
+ checkCondition(!'a || 'a, testRelationWithData)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index 589607e3ad5cb..a0a0daea7d075 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -321,15 +321,14 @@ class ColumnPruningSuite extends PlanTest {
Project(Seq($"x.key", $"y.key"),
Join(
SubqueryAlias("x", input),
- BroadcastHint(SubqueryAlias("y", input)), Inner, None)).analyze
+ ResolvedHint(SubqueryAlias("y", input)), Inner, None)).analyze
val optimized = Optimize.execute(query)
val expected =
Join(
Project(Seq($"x.key"), SubqueryAlias("x", input)),
- BroadcastHint(
- Project(Seq($"y.key"), SubqueryAlias("y", input))),
+ ResolvedHint(Project(Seq($"y.key"), SubqueryAlias("y", input))),
Inner, None).analyze
comparePlans(optimized, expected)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala
new file mode 100644
index 0000000000000..d4f37e2a5e877
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.expressions.objects.Invoke
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{DeserializeToObject, LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.types._
+
+class EliminateMapObjectsSuite extends PlanTest {
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches = {
+ Batch("EliminateMapObjects", FixedPoint(50),
+ NullPropagation(conf),
+ SimplifyCasts,
+ EliminateMapObjects) :: Nil
+ }
+ }
+
+ implicit private def intArrayEncoder = ExpressionEncoder[Array[Int]]()
+ implicit private def doubleArrayEncoder = ExpressionEncoder[Array[Double]]()
+
+ test("SPARK-20254: Remove unnecessary data conversion for primitive array") {
+ val intObjType = ObjectType(classOf[Array[Int]])
+ val intInput = LocalRelation('a.array(ArrayType(IntegerType, false)))
+ val intQuery = intInput.deserialize[Array[Int]].analyze
+ val intOptimized = Optimize.execute(intQuery)
+ val intExpected = DeserializeToObject(
+ Invoke(intInput.output(0), "toIntArray", intObjType, Nil, true, false),
+ AttributeReference("obj", intObjType, true)(), intInput)
+ comparePlans(intOptimized, intExpected)
+
+ val doubleObjType = ObjectType(classOf[Array[Double]])
+ val doubleInput = LocalRelation('a.array(ArrayType(DoubleType, false)))
+ val doubleQuery = doubleInput.deserialize[Array[Double]].analyze
+ val doubleOptimized = Optimize.execute(doubleQuery)
+ val doubleExpected = DeserializeToObject(
+ Invoke(doubleInput.output(0), "toDoubleArray", doubleObjType, Nil, true, false),
+ AttributeReference("obj", doubleObjType, true)(), doubleInput)
+ comparePlans(doubleOptimized, doubleExpected)
+ }
+}
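Editor's note: the new EliminateMapObjects rule targets deserialization of arrays of non-nullable primitives, where the per-element MapObjects loop can be replaced by a single `toIntArray`/`toDoubleArray` call. A hedged sketch of how user code would hit this path (assumes a local SparkSession named `spark`; the plan shape is what the test above asserts, not something this snippet verifies):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("sketch").getOrCreate()
import spark.implicits._

// Deserializing back to Array[Int] should now plan a DeserializeToObject with
// Invoke(..., "toIntArray", ...) instead of converting element by element.
val ds = Seq(Array(1, 2, 3), Array(4, 5)).toDS()
ds.map(_.sum).collect()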
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index 950aa2379517e..4d41354185a96 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -797,13 +797,26 @@ class FilterPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
+ test("aggregate: don't push filters if the aggregate has no grouping expressions") {
+ val originalQuery = LocalRelation.apply(testRelation.output, Seq.empty)
+ .select('a, 'b)
+ .groupBy()(count(1))
+ .where(false)
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+
+ val correctAnswer = originalQuery.analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
test("broadcast hint") {
- val originalQuery = BroadcastHint(testRelation)
+ val originalQuery = ResolvedHint(testRelation)
.where('a === 2L && 'b + Rand(10).as("rnd") === 3)
val optimized = Optimize.execute(originalQuery.analyze)
- val correctAnswer = BroadcastHint(testRelation.where('a === 2L))
+ val correctAnswer = ResolvedHint(testRelation.where('a === 2L))
.where('b + Rand(10).as("rnd") === 3)
.analyze
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
index c8fe37462726a..9a4bcdb011435 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
@@ -33,7 +33,8 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
PushPredicateThroughJoin,
PushDownPredicate,
InferFiltersFromConstraints(conf),
- CombineFilters) :: Nil
+ CombineFilters,
+ BooleanSimplification) :: Nil
}
object OptimizeWithConstraintPropagationDisabled extends RuleExecutor[LogicalPlan] {
@@ -172,7 +173,12 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
val t1 = testRelation.subquery('t1)
val t2 = testRelation.subquery('t2)
- val originalQuery = t1.select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)).as("t")
+ // We should prevent `Coalesce(a, b)` from recursively creating complicated constraints through
+ // the constraint inference procedure.
+ val originalQuery = t1.select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col))
+ // We hide an `Alias` inside the child's child's expressions, to cover the situation reported
+ // in [SPARK-20700].
+ .select('int_col, 'd, 'a).as("t")
.join(t2, Inner,
Some("t.a".attr === "t2.a".attr
&& "t.d".attr === "t2.a".attr
@@ -180,22 +186,18 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
.analyze
val correctAnswer = t1
.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a)))
- && 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a)) && 'a <=> 'a
- && Coalesce(Seq('a, 'a)) <=> 'b && Coalesce(Seq('a, 'a)) <=> Coalesce(Seq('a, 'a))
- && 'a === 'b && IsNotNull(Coalesce(Seq('a, 'b))) && 'a === Coalesce(Seq('a, 'b))
- && Coalesce(Seq('a, 'b)) <=> Coalesce(Seq('b, 'b)) && Coalesce(Seq('a, 'b)) === 'b
+ && 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a))
+ && Coalesce(Seq('b, 'b)) <=> 'a && 'a === 'b && IsNotNull(Coalesce(Seq('a, 'b)))
+ && 'a === Coalesce(Seq('a, 'b)) && Coalesce(Seq('a, 'b)) === 'b
&& IsNotNull('b) && IsNotNull(Coalesce(Seq('b, 'b)))
- && 'b === Coalesce(Seq('b, 'b)) && 'b <=> Coalesce(Seq('b, 'b))
- && Coalesce(Seq('b, 'b)) <=> Coalesce(Seq('b, 'b)) && 'b <=> 'b)
- .select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)).as("t")
+ && 'b === Coalesce(Seq('b, 'b)) && 'b <=> Coalesce(Seq('b, 'b)))
+ .select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col))
+ .select('int_col, 'd, 'a).as("t")
.join(t2
.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a)))
- && 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a)) && 'a <=> 'a
- && Coalesce(Seq('a, 'a)) <=> Coalesce(Seq('a, 'a))), Inner,
- Some("t.a".attr === "t2.a".attr
- && "t.d".attr === "t2.a".attr
- && "t.int_col".attr === "t2.a".attr
- && Coalesce(Seq("t.d".attr, "t.d".attr)) <=> "t.int_col".attr))
+ && 'a <=> Coalesce(Seq('a, 'a)) && 'a === Coalesce(Seq('a, 'a)) && 'a <=> 'a), Inner,
+ Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr
+ && "t.int_col".attr === "t2.a".attr))
.analyze
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, correctAnswer)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
index a43d78c7bd447..105407d43bf39 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
@@ -129,14 +129,14 @@ class JoinOptimizationSuite extends PlanTest {
Project(Seq($"x.key", $"y.key"),
Join(
SubqueryAlias("x", input),
- BroadcastHint(SubqueryAlias("y", input)), Cross, None)).analyze
+ ResolvedHint(SubqueryAlias("y", input)), Cross, None)).analyze
val optimized = Optimize.execute(query)
val expected =
Join(
Project(Seq($"x.key"), SubqueryAlias("x", input)),
- BroadcastHint(Project(Seq($"y.key"), SubqueryAlias("y", input))),
+ ResolvedHint(Project(Seq($"y.key"), SubqueryAlias("y", input))),
Cross, None).analyze
comparePlans(optimized, expected)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala
index fdde89d079bc0..50398788c605c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala
@@ -17,13 +17,13 @@
package org.apache.spark.sql.catalyst.optimizer
-/* Implicit conversions */
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.types.{BooleanType, StringType}
class LikeSimplificationSuite extends PlanTest {
@@ -100,4 +100,10 @@ class LikeSimplificationSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
+
+ test("null pattern") {
+ val originalQuery = testRelation.where('a like Literal(null, StringType)).analyze
+ val optimized = Optimize.execute(originalQuery)
+ comparePlans(optimized, testRelation.where(Literal(null, BooleanType)).analyze)
+ }
}
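Editor's note: the "null pattern" test covers `LIKE` with a null pattern, which is null for every row, so the whole predicate simplifies to a null boolean literal and the filter keeps nothing. A small usage sketch (assumes a SparkSession named `spark`):

// `x LIKE NULL` evaluates to NULL for every row, and a NULL predicate filters the row out.
spark.sql("SELECT 'abc' LIKE NULL").collect()                    // single row containing NULL
spark.range(3).where("CAST(id AS STRING) LIKE NULL").count()     // 0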
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala
index 2885fd6841e9d..540b50d654e22 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala
@@ -111,35 +111,34 @@ class LimitPushdownSuite extends PlanTest {
test("full outer join where neither side is limited and both sides have same statistics") {
assert(x.stats(conf).sizeInBytes === y.stats(conf).sizeInBytes)
- val originalQuery = x.join(y, FullOuter).limit(1)
- val optimized = Optimize.execute(originalQuery.analyze)
- val correctAnswer = Limit(1, LocalLimit(1, x).join(y, FullOuter)).analyze
- comparePlans(optimized, correctAnswer)
+ val originalQuery = x.join(y, FullOuter).limit(1).analyze
+ val optimized = Optimize.execute(originalQuery)
+ // No pushdown for FULL OUTER JOINS.
+ comparePlans(optimized, originalQuery)
}
test("full outer join where neither side is limited and left side has larger statistics") {
val xBig = testRelation.copy(data = Seq.fill(2)(null)).subquery('x)
assert(xBig.stats(conf).sizeInBytes > y.stats(conf).sizeInBytes)
- val originalQuery = xBig.join(y, FullOuter).limit(1)
- val optimized = Optimize.execute(originalQuery.analyze)
- val correctAnswer = Limit(1, LocalLimit(1, xBig).join(y, FullOuter)).analyze
- comparePlans(optimized, correctAnswer)
+ val originalQuery = xBig.join(y, FullOuter).limit(1).analyze
+ val optimized = Optimize.execute(originalQuery)
+ // No pushdown for FULL OUTER JOINS.
+ comparePlans(optimized, originalQuery)
}
test("full outer join where neither side is limited and right side has larger statistics") {
val yBig = testRelation.copy(data = Seq.fill(2)(null)).subquery('y)
assert(x.stats(conf).sizeInBytes < yBig.stats(conf).sizeInBytes)
- val originalQuery = x.join(yBig, FullOuter).limit(1)
- val optimized = Optimize.execute(originalQuery.analyze)
- val correctAnswer = Limit(1, x.join(LocalLimit(1, yBig), FullOuter)).analyze
- comparePlans(optimized, correctAnswer)
+ val originalQuery = x.join(yBig, FullOuter).limit(1).analyze
+ val optimized = Optimize.execute(originalQuery)
+ // No pushdown for FULL OUTER JOINS.
+ comparePlans(optimized, originalQuery)
}
test("full outer join where both sides are limited") {
- val originalQuery = x.limit(2).join(y.limit(2), FullOuter).limit(1)
- val optimized = Optimize.execute(originalQuery.analyze)
- val correctAnswer = Limit(1, Limit(2, x).join(Limit(2, y), FullOuter)).analyze
- comparePlans(optimized, correctAnswer)
+ val originalQuery = x.limit(2).join(y.limit(2), FullOuter).limit(1).analyze
+ val optimized = Optimize.execute(originalQuery)
+ // No pushdown for FULL OUTER JOINS.
+ comparePlans(optimized, originalQuery)
}
}
-
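Editor's note: these rewritten tests pin down that LimitPushDown no longer injects a LocalLimit into either side of a FULL OUTER JOIN; the limit stays above the join. A hedged sketch of the query shape involved (assumes a SparkSession named `spark`; `explain` is only there to inspect the optimized plan):

import spark.implicits._

val left  = Seq(1, 2).toDF("k")
val right = Seq(2, 3).toDF("k")
val q = left.join(right, left("k") === right("k"), "full_outer").limit(1)
// The optimizer keeps GlobalLimit/LocalLimit above the join and does not push a
// LocalLimit into either join side anymore.
q.explain(true)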
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
index d8937321ecb98..f12f0f5eb4cd4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
@@ -166,7 +166,7 @@ class OptimizeInSuite extends PlanTest {
val optimizedPlan = OptimizeIn(conf.copy(OPTIMIZER_INSET_CONVERSION_THRESHOLD -> 2))(plan)
optimizedPlan match {
case Filter(cond, _)
- if cond.isInstanceOf[InSet] && cond.asInstanceOf[InSet].getHSet().size == 3 =>
+ if cond.isInstanceOf[InSet] && cond.asInstanceOf[InSet].getSet().size == 3 =>
// pass
case _ => fail("Unexpected result for OptimizedIn")
}
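Editor's note: the renamed accessor `getSet()` exposes the backing set of an `InSet`, which OptimizeIn produces once the IN-list reaches the conversion threshold. A sketch of how to observe this from the SQL layer; the config key is assumed to be the one behind `OPTIMIZER_INSET_CONVERSION_THRESHOLD` (assumes a SparkSession named `spark`):

spark.conf.set("spark.sql.optimizer.inSetConversionThreshold", "2")
val df = spark.range(10).where("id IN (1, 2, 3)")
// With the threshold lowered to 2, the optimized plan should contain an InSet
// whose backing set (getSet()) has 3 elements.
df.explain(true)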
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala
index c261a6091d476..38dff4733f714 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala
@@ -142,7 +142,7 @@ class PropagateEmptyRelationSuite extends PlanTest {
comparePlans(optimized, correctAnswer.analyze)
}
- test("propagate empty relation through Aggregate without aggregate function") {
+ test("propagate empty relation through Aggregate with grouping expressions") {
val query = testRelation1
.where(false)
.groupBy('a)('a, ('a + 1).as('x))
@@ -153,13 +153,13 @@ class PropagateEmptyRelationSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
- test("don't propagate empty relation through Aggregate with aggregate function") {
+ test("don't propagate empty relation through Aggregate without grouping expressions") {
val query = testRelation1
.where(false)
- .groupBy('a)(count('a))
+ .groupBy()()
val optimized = Optimize.execute(query.analyze)
- val correctAnswer = LocalRelation('a.int).groupBy('a)(count('a)).analyze
+ val correctAnswer = LocalRelation('a.int).groupBy()().analyze
comparePlans(optimized, correctAnswer)
}
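Editor's note: the renamed tests reflect why grouping expressions, not aggregate functions, decide whether an empty relation may propagate: a grouped aggregate over zero rows yields zero rows, while a global aggregate still yields one row (e.g. `count` returns 0). A short illustration (assumes a SparkSession named `spark`):

import org.apache.spark.sql.functions.count
import spark.implicits._

val empty = spark.emptyDataset[Int].toDF("a")
empty.groupBy("a").agg(count("a")).count()   // 0 rows: safe to collapse to an empty relation
empty.agg(count("a")).count()                // 1 row (count = 0): must not be collapsed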
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala
index 3964fa3924b24..4490523369006 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala
@@ -30,7 +30,7 @@ class DataTypeParserSuite extends SparkFunSuite {
}
}
- def intercept(sql: String): Unit =
+ def intercept(sql: String): ParseException =
intercept[ParseException](CatalystSqlParser.parseDataType(sql))
def unsupported(dataTypeString: String): Unit = {
@@ -118,6 +118,11 @@ class DataTypeParserSuite extends SparkFunSuite {
unsupported("struct")
+ test("Do not print empty parentheses for no params") {
+ assert(intercept("unkwon").getMessage.contains("unkwon is not supported"))
+ assert(intercept("unkwon(1,2,3)").getMessage.contains("unkwon(1,2,3) is not supported"))
+ }
+
// DataType parser accepts certain reserved keywords.
checkDataType(
"Struct",
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
index e7f3b64a71130..5ae8d2e849828 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, _}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last}
import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
@@ -39,12 +40,17 @@ class ExpressionParserSuite extends PlanTest {
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
- def assertEqual(sqlCommand: String, e: Expression): Unit = {
- compareExpressions(parseExpression(sqlCommand), e)
+ val defaultParser = CatalystSqlParser
+
+ def assertEqual(
+ sqlCommand: String,
+ e: Expression,
+ parser: ParserInterface = defaultParser): Unit = {
+ compareExpressions(parser.parseExpression(sqlCommand), e)
}
def intercept(sqlCommand: String, messages: String*): Unit = {
- val e = intercept[ParseException](parseExpression(sqlCommand))
+ val e = intercept[ParseException](defaultParser.parseExpression(sqlCommand))
messages.foreach { message =>
assert(e.message.contains(message))
}
@@ -101,7 +107,7 @@ class ExpressionParserSuite extends PlanTest {
test("long binary logical expressions") {
def testVeryBinaryExpression(op: String, clazz: Class[_]): Unit = {
val sql = (1 to 1000).map(x => s"$x == $x").mkString(op)
- val e = parseExpression(sql)
+ val e = defaultParser.parseExpression(sql)
assert(e.collect { case _: EqualTo => true }.size === 1000)
assert(e.collect { case x if clazz.isInstance(x) => true }.size === 999)
}
@@ -160,6 +166,15 @@ class ExpressionParserSuite extends PlanTest {
assertEqual("a not regexp 'pattern%'", !('a rlike "pattern%"))
}
+ test("like expressions with ESCAPED_STRING_LITERALS = true") {
+ val conf = new SQLConf()
+ conf.setConfString(SQLConf.ESCAPED_STRING_LITERALS.key, "true")
+ val parser = new CatalystSqlParser(conf)
+ assertEqual("a rlike '^\\x20[\\x20-\\x23]+$'", 'a rlike "^\\x20[\\x20-\\x23]+$", parser)
+ assertEqual("a rlike 'pattern\\\\'", 'a rlike "pattern\\\\", parser)
+ assertEqual("a rlike 'pattern\\t\\n'", 'a rlike "pattern\\t\\n", parser)
+ }
+
test("is null expressions") {
assertEqual("a is null", 'a.isNull)
assertEqual("a is not null", 'a.isNotNull)
@@ -211,7 +226,7 @@ class ExpressionParserSuite extends PlanTest {
assertEqual("foo(distinct a, b)", 'foo.distinctFunction('a, 'b))
assertEqual("grouping(distinct a, b)", 'grouping.distinctFunction('a, 'b))
assertEqual("`select`(all a, b)", 'select.function('a, 'b))
- assertEqual("foo(a as x, b as e)", 'foo.function('a as 'x, 'b as 'e))
+ intercept("foo(a x)", "extraneous input 'x'")
}
test("window function expressions") {
@@ -310,7 +325,9 @@ class ExpressionParserSuite extends PlanTest {
assertEqual("a.b", UnresolvedAttribute("a.b"))
assertEqual("`select`.b", UnresolvedAttribute("select.b"))
assertEqual("(a + b).b", ('a + 'b).getField("b")) // This will fail analysis.
- assertEqual("struct(a, b).b", 'struct.function('a, 'b).getField("b"))
+ assertEqual(
+ "struct(a, b).b",
+ namedStruct(NamePlaceholder, 'a, NamePlaceholder, 'b).getField("b"))
}
test("reference") {
@@ -413,38 +430,87 @@ class ExpressionParserSuite extends PlanTest {
}
test("strings") {
- // Single Strings.
- assertEqual("\"hello\"", "hello")
- assertEqual("'hello'", "hello")
-
- // Multi-Strings.
- assertEqual("\"hello\" 'world'", "helloworld")
- assertEqual("'hello' \" \" 'world'", "hello world")
-
- // 'LIKE' string literals. Notice that an escaped '%' is the same as an escaped '\' and a
- // regular '%'; to get the correct result you need to add another escaped '\'.
- // TODO figure out if we shouldn't change the ParseUtils.unescapeSQLString method?
- assertEqual("'pattern%'", "pattern%")
- assertEqual("'no-pattern\\%'", "no-pattern\\%")
- assertEqual("'pattern\\\\%'", "pattern\\%")
- assertEqual("'pattern\\\\\\%'", "pattern\\\\%")
-
- // Escaped characters.
- // See: http://dev.mysql.com/doc/refman/5.7/en/string-literals.html
- assertEqual("'\\0'", "\u0000") // ASCII NUL (X'00')
- assertEqual("'\\''", "\'") // Single quote
- assertEqual("'\\\"'", "\"") // Double quote
- assertEqual("'\\b'", "\b") // Backspace
- assertEqual("'\\n'", "\n") // Newline
- assertEqual("'\\r'", "\r") // Carriage return
- assertEqual("'\\t'", "\t") // Tab character
- assertEqual("'\\Z'", "\u001A") // ASCII 26 - CTRL + Z (EOF on windows)
-
- // Octals
- assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!")
-
- // Unicode
- assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", "World :)")
+ Seq(true, false).foreach { escape =>
+ val conf = new SQLConf()
+ conf.setConfString(SQLConf.ESCAPED_STRING_LITERALS.key, escape.toString)
+ val parser = new CatalystSqlParser(conf)
+
+ // tests that give the same result regardless of the conf
+ // Single Strings.
+ assertEqual("\"hello\"", "hello", parser)
+ assertEqual("'hello'", "hello", parser)
+
+ // Multi-Strings.
+ assertEqual("\"hello\" 'world'", "helloworld", parser)
+ assertEqual("'hello' \" \" 'world'", "hello world", parser)
+
+ // 'LIKE' string literals. Notice that an escaped '%' is the same as an escaped '\' and a
+ // regular '%'; to get the correct result you need to add another escaped '\'.
+ // TODO figure out if we shouldn't change the ParseUtils.unescapeSQLString method?
+ assertEqual("'pattern%'", "pattern%", parser)
+ assertEqual("'no-pattern\\%'", "no-pattern\\%", parser)
+
+ // tests whose result depends on the conf
+ if (escape) {
+ // When SQLConf.ESCAPED_STRING_LITERALS is enabled, string literal parsing falls back to
+ // Spark 1.6 behavior.
+
+ // 'LIKE' string literals.
+ assertEqual("'pattern\\\\%'", "pattern\\\\%", parser)
+ assertEqual("'pattern\\\\\\%'", "pattern\\\\\\%", parser)
+
+ // Escaped characters.
+ // Unescaping the string literal "'\\0'" into ASCII NUL (X'00') doesn't work
+ // when ESCAPED_STRING_LITERALS is enabled.
+ // It is parsed literally.
+ assertEqual("'\\0'", "\\0", parser)
+
+ // Note: Single quote follows 1.6 parsing behavior when ESCAPED_STRING_LITERALS is enabled.
+ val e = intercept[ParseException](parser.parseExpression("'\''"))
+ assert(e.message.contains("extraneous input '''"))
+
+ // The 2.0+ unescaping of special characters (e.g., "\\t") doesn't work
+ // when ESCAPED_STRING_LITERALS is enabled. They are parsed literally.
+ assertEqual("'\\\"'", "\\\"", parser) // Double quote
+ assertEqual("'\\b'", "\\b", parser) // Backspace
+ assertEqual("'\\n'", "\\n", parser) // Newline
+ assertEqual("'\\r'", "\\r", parser) // Carriage return
+ assertEqual("'\\t'", "\\t", parser) // Tab character
+
+ // The 2.0+ unescaping of octal sequences doesn't work when ESCAPED_STRING_LITERALS is enabled.
+ // They are parsed literally.
+ assertEqual("'\\110\\145\\154\\154\\157\\041'", "\\110\\145\\154\\154\\157\\041", parser)
+ // The 2.0+ unescaping of Unicode sequences doesn't work when ESCAPED_STRING_LITERALS is enabled.
+ // They are parsed literally.
+ assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'",
+ "\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029", parser)
+ } else {
+ // Default behavior
+
+ // 'LIKE' string literals.
+ assertEqual("'pattern\\\\%'", "pattern\\%", parser)
+ assertEqual("'pattern\\\\\\%'", "pattern\\\\%", parser)
+
+ // Escaped characters.
+ // See: http://dev.mysql.com/doc/refman/5.7/en/string-literals.html
+ assertEqual("'\\0'", "\u0000", parser) // ASCII NUL (X'00')
+ assertEqual("'\\''", "\'", parser) // Single quote
+ assertEqual("'\\\"'", "\"", parser) // Double quote
+ assertEqual("'\\b'", "\b", parser) // Backspace
+ assertEqual("'\\n'", "\n", parser) // Newline
+ assertEqual("'\\r'", "\r", parser) // Carriage return
+ assertEqual("'\\t'", "\t", parser) // Tab character
+ assertEqual("'\\Z'", "\u001A", parser) // ASCII 26 - CTRL + Z (EOF on windows)
+
+ // Octals
+ assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!", parser)
+
+ // Unicode
+ assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", "World :)",
+ parser)
+ }
+
+ }
}
test("intervals") {
@@ -524,11 +590,6 @@ class ExpressionParserSuite extends PlanTest {
intercept("1 - f('o', o(bar)) hello * world", "mismatched input '*'")
}
- test("current date/timestamp braceless expressions") {
- assertEqual("current_date", CurrentDate())
- assertEqual("current_timestamp", CurrentTimestamp())
- }
-
test("SPARK-17364, fully qualified column name which starts with number") {
assertEqual("123_", UnresolvedAttribute("123_"))
assertEqual("1a.123_", UnresolvedAttribute("1a.123_"))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 411777d6e85a2..950f152b94b4d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.parser
import org.apache.spark.sql.catalyst.FunctionIdentifier
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedGenerator, UnresolvedInlineTable, UnresolvedTableValuedFunction}
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedInlineTable, UnresolvedTableValuedFunction}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
@@ -176,14 +176,14 @@ class PlanParserSuite extends PlanTest {
def insert(
partition: Map[String, Option[String]],
overwrite: Boolean = false,
- ifNotExists: Boolean = false): LogicalPlan =
- InsertIntoTable(table("s"), partition, plan, overwrite, ifNotExists)
+ ifPartitionNotExists: Boolean = false): LogicalPlan =
+ InsertIntoTable(table("s"), partition, plan, overwrite, ifPartitionNotExists)
// Single inserts
assertEqual(s"insert overwrite table s $sql",
insert(Map.empty, overwrite = true))
assertEqual(s"insert overwrite table s partition (e = 1) if not exists $sql",
- insert(Map("e" -> Option("1")), overwrite = true, ifNotExists = true))
+ insert(Map("e" -> Option("1")), overwrite = true, ifPartitionNotExists = true))
assertEqual(s"insert into s $sql",
insert(Map.empty))
assertEqual(s"insert into table s partition (c = 'd', e = 1) $sql",
@@ -193,9 +193,9 @@ class PlanParserSuite extends PlanTest {
val plan2 = table("t").where('x > 5).select(star())
assertEqual("from t insert into s select * limit 1 insert into u select * where x > 5",
InsertIntoTable(
- table("s"), Map.empty, plan.limit(1), false, ifNotExists = false).union(
+ table("s"), Map.empty, plan.limit(1), false, ifPartitionNotExists = false).union(
InsertIntoTable(
- table("u"), Map.empty, plan2, false, ifNotExists = false)))
+ table("u"), Map.empty, plan2, false, ifPartitionNotExists = false)))
}
test ("insert with if not exists") {
@@ -223,6 +223,12 @@ class PlanParserSuite extends PlanTest {
assertEqual(s"$sql grouping sets((a, b), (a), ())",
GroupingSets(Seq(Seq('a, 'b), Seq('a), Seq()), Seq('a, 'b), table("d"),
Seq('a, 'b, 'sum.function('c).as("c"))))
+
+ val m = intercept[ParseException] {
+ parsePlan("SELECT a, b, count(distinct a, distinct b) as c FROM d GROUP BY a, b")
+ }.getMessage
+ assert(m.contains("extraneous input 'b'"))
+
}
test("limit") {
@@ -496,46 +502,109 @@ class PlanParserSuite extends PlanTest {
val m = intercept[ParseException] {
parsePlan("SELECT /*+ HINT() */ * FROM t")
}.getMessage
- assert(m.contains("no viable alternative at input"))
-
- // Hive compatibility: No database.
- val m2 = intercept[ParseException] {
- parsePlan("SELECT /*+ MAPJOIN(default.t) */ * from default.t")
- }.getMessage
- assert(m2.contains("mismatched input '.' expecting {')', ','}"))
+ assert(m.contains("mismatched input"))
// Disallow space as the delimiter.
val m3 = intercept[ParseException] {
parsePlan("SELECT /*+ INDEX(a b c) */ * from default.t")
}.getMessage
- assert(m3.contains("mismatched input 'b' expecting {')', ','}"))
+ assert(m3.contains("mismatched input 'b' expecting"))
comparePlans(
parsePlan("SELECT /*+ HINT */ * FROM t"),
- Hint("HINT", Seq.empty, table("t").select(star())))
+ UnresolvedHint("HINT", Seq.empty, table("t").select(star())))
comparePlans(
parsePlan("SELECT /*+ BROADCASTJOIN(u) */ * FROM t"),
- Hint("BROADCASTJOIN", Seq("u"), table("t").select(star())))
+ UnresolvedHint("BROADCASTJOIN", Seq($"u"), table("t").select(star())))
comparePlans(
parsePlan("SELECT /*+ MAPJOIN(u) */ * FROM t"),
- Hint("MAPJOIN", Seq("u"), table("t").select(star())))
+ UnresolvedHint("MAPJOIN", Seq($"u"), table("t").select(star())))
comparePlans(
parsePlan("SELECT /*+ STREAMTABLE(a,b,c) */ * FROM t"),
- Hint("STREAMTABLE", Seq("a", "b", "c"), table("t").select(star())))
+ UnresolvedHint("STREAMTABLE", Seq($"a", $"b", $"c"), table("t").select(star())))
comparePlans(
parsePlan("SELECT /*+ INDEX(t, emp_job_ix) */ * FROM t"),
- Hint("INDEX", Seq("t", "emp_job_ix"), table("t").select(star())))
+ UnresolvedHint("INDEX", Seq($"t", $"emp_job_ix"), table("t").select(star())))
comparePlans(
parsePlan("SELECT /*+ MAPJOIN(`default.t`) */ * from `default.t`"),
- Hint("MAPJOIN", Seq("default.t"), table("default.t").select(star())))
+ UnresolvedHint("MAPJOIN", Seq(UnresolvedAttribute.quoted("default.t")),
+ table("default.t").select(star())))
comparePlans(
parsePlan("SELECT /*+ MAPJOIN(t) */ a from t where true group by a order by a"),
- Hint("MAPJOIN", Seq("t"), table("t").where(Literal(true)).groupBy('a)('a)).orderBy('a.asc))
+ UnresolvedHint("MAPJOIN", Seq($"t"),
+ table("t").where(Literal(true)).groupBy('a)('a)).orderBy('a.asc))
+ }
+
+ test("SPARK-20854: select hint syntax with expressions") {
+ comparePlans(
+ parsePlan("SELECT /*+ HINT1(a, array(1, 2, 3)) */ * from t"),
+ UnresolvedHint("HINT1", Seq($"a",
+ UnresolvedFunction("array", Literal(1) :: Literal(2) :: Literal(3) :: Nil, false)),
+ table("t").select(star())
+ )
+ )
+
+ comparePlans(
+ parsePlan("SELECT /*+ HINT1(a, 5, 'a', b) */ * from t"),
+ UnresolvedHint("HINT1", Seq($"a", Literal(5), Literal("a"), $"b"),
+ table("t").select(star())
+ )
+ )
+
+ comparePlans(
+ parsePlan("SELECT /*+ HINT1('a', (b, c), (1, 2)) */ * from t"),
+ UnresolvedHint("HINT1",
+ Seq(Literal("a"),
+ CreateStruct($"b" :: $"c" :: Nil),
+ CreateStruct(Literal(1) :: Literal(2) :: Nil)),
+ table("t").select(star())
+ )
+ )
+ }
+
+ test("SPARK-20854: multiple hints") {
+ comparePlans(
+ parsePlan("SELECT /*+ HINT1(a, 1) hint2(b, 2) */ * from t"),
+ UnresolvedHint("HINT1", Seq($"a", Literal(1)),
+ UnresolvedHint("hint2", Seq($"b", Literal(2)),
+ table("t").select(star())
+ )
+ )
+ )
+
+ comparePlans(
+ parsePlan("SELECT /*+ HINT1(a, 1),hint2(b, 2) */ * from t"),
+ UnresolvedHint("HINT1", Seq($"a", Literal(1)),
+ UnresolvedHint("hint2", Seq($"b", Literal(2)),
+ table("t").select(star())
+ )
+ )
+ )
+
+ comparePlans(
+ parsePlan("SELECT /*+ HINT1(a, 1) */ /*+ hint2(b, 2) */ * from t"),
+ UnresolvedHint("HINT1", Seq($"a", Literal(1)),
+ UnresolvedHint("hint2", Seq($"b", Literal(2)),
+ table("t").select(star())
+ )
+ )
+ )
+
+ comparePlans(
+ parsePlan("SELECT /*+ HINT1(a, 1), hint2(b, 2) */ /*+ hint3(c, 3) */ * from t"),
+ UnresolvedHint("HINT1", Seq($"a", Literal(1)),
+ UnresolvedHint("hint2", Seq($"b", Literal(2)),
+ UnresolvedHint("hint3", Seq($"c", Literal(3)),
+ table("t").select(star())
+ )
+ )
+ )
+ )
}
}
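Editor's note: the new hint tests show that hint parameters may now be arbitrary expressions and that several hints, comma-separated or in separate comments, nest into chained UnresolvedHint nodes. A parse-only sketch using the session's parser (an internal API, shown as an assumption; a SparkSession named `spark` is assumed):

val plan = spark.sessionState.sqlParser.parsePlan(
  "SELECT /*+ HINT1(a, 1), hint2(b, 2) */ * FROM t")
// The plan should be UnresolvedHint("HINT1", ...) wrapping UnresolvedHint("hint2", ...),
// which in turn wraps the projection over table `t`. Parsing alone does not require
// the table to exist.
println(plan.treeString)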
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala
index 467f76193cfc5..7c8ed78a49116 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Union}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, ResolvedHint, Union}
import org.apache.spark.sql.catalyst.util._
/**
@@ -66,4 +66,10 @@ class SameResultSuite extends SparkFunSuite {
assertSameResult(Union(Seq(testRelation, testRelation2)),
Union(Seq(testRelation2, testRelation)))
}
+
+ test("hint") {
+ val df1 = testRelation.join(ResolvedHint(testRelation))
+ val df2 = testRelation.join(testRelation)
+ assertSameResult(df1, df2)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
index b06871f96f0d8..2afea6dd3d37c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
@@ -37,19 +37,20 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase {
test("BroadcastHint estimation") {
val filter = Filter(Literal(true), plan)
- val filterStatsCboOn = Statistics(sizeInBytes = 10 * (8 +4), isBroadcastable = false,
+ val filterStatsCboOn = Statistics(sizeInBytes = 10 * (8 +4),
rowCount = Some(10), attributeStats = AttributeMap(Seq(attribute -> colStat)))
- val filterStatsCboOff = Statistics(sizeInBytes = 10 * (8 +4), isBroadcastable = false)
+ val filterStatsCboOff = Statistics(sizeInBytes = 10 * (8 +4))
checkStats(
filter,
expectedStatsCboOn = filterStatsCboOn,
expectedStatsCboOff = filterStatsCboOff)
- val broadcastHint = BroadcastHint(filter)
+ val broadcastHint = ResolvedHint(filter, HintInfo(isBroadcastable = Option(true)))
checkStats(
broadcastHint,
- expectedStatsCboOn = filterStatsCboOn.copy(isBroadcastable = true),
- expectedStatsCboOff = filterStatsCboOff.copy(isBroadcastable = true))
+ expectedStatsCboOn = filterStatsCboOn.copy(hints = HintInfo(isBroadcastable = Option(true))),
+ expectedStatsCboOff = filterStatsCboOff.copy(hints = HintInfo(isBroadcastable = Option(true)))
+ )
}
test("limit estimation: limit < child's rowCount") {
@@ -94,15 +95,13 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase {
sizeInBytes = 40,
rowCount = Some(10),
attributeStats = AttributeMap(Seq(
- AttributeReference("c1", IntegerType)() -> ColumnStat(10, Some(1), Some(10), 0, 4, 4))),
- isBroadcastable = false)
+ AttributeReference("c1", IntegerType)() -> ColumnStat(10, Some(1), Some(10), 0, 4, 4))))
val expectedCboStats =
Statistics(
sizeInBytes = 4,
rowCount = Some(1),
attributeStats = AttributeMap(Seq(
- AttributeReference("c1", IntegerType)() -> ColumnStat(1, Some(5), Some(5), 0, 4, 4))),
- isBroadcastable = false)
+ AttributeReference("c1", IntegerType)() -> ColumnStat(1, Some(5), Some(5), 0, 4, 4))))
val plan = DummyLogicalPlan(defaultStats = expectedDefaultStats, cboStats = expectedCboStats)
checkStats(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
index a28447840ae09..2fa53a6466ef2 100755
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
@@ -150,7 +150,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
val condition = Or(LessThan(attrInt, Literal(3)), Literal(null, IntegerType))
validateEstimatedStats(
Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)),
- Seq(attrInt -> colStatInt),
+ Seq(attrInt -> colStatInt.copy(distinctCount = 3)),
expectedRowCount = 3)
}
@@ -158,7 +158,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
val condition = Not(And(LessThan(attrInt, Literal(3)), Literal(null, IntegerType)))
validateEstimatedStats(
Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)),
- Seq(attrInt -> colStatInt),
+ Seq(attrInt -> colStatInt.copy(distinctCount = 8)),
expectedRowCount = 8)
}
@@ -174,7 +174,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
val condition = Not(And(LessThan(attrInt, Literal(3)), Not(Literal(null, IntegerType))))
validateEstimatedStats(
Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)),
- Seq(attrInt -> colStatInt),
+ Seq(attrInt -> colStatInt.copy(distinctCount = 8)),
expectedRowCount = 8)
}
@@ -205,7 +205,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
test("cint < 3") {
validateEstimatedStats(
Filter(LessThan(attrInt, Literal(3)), childStatsTestPlan(Seq(attrInt), 10L)),
- Seq(attrInt -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(3),
+ Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3),
nullCount = 0, avgLen = 4, maxLen = 4)),
expectedRowCount = 3)
}
@@ -221,7 +221,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
test("cint <= 3") {
validateEstimatedStats(
Filter(LessThanOrEqual(attrInt, Literal(3)), childStatsTestPlan(Seq(attrInt), 10L)),
- Seq(attrInt -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(3),
+ Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3),
nullCount = 0, avgLen = 4, maxLen = 4)),
expectedRowCount = 3)
}
@@ -229,7 +229,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
test("cint > 6") {
validateEstimatedStats(
Filter(GreaterThan(attrInt, Literal(6)), childStatsTestPlan(Seq(attrInt), 10L)),
- Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(6), max = Some(10),
+ Seq(attrInt -> ColumnStat(distinctCount = 5, min = Some(6), max = Some(10),
nullCount = 0, avgLen = 4, maxLen = 4)),
expectedRowCount = 5)
}
@@ -245,7 +245,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
test("cint >= 6") {
validateEstimatedStats(
Filter(GreaterThanOrEqual(attrInt, Literal(6)), childStatsTestPlan(Seq(attrInt), 10L)),
- Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(6), max = Some(10),
+ Seq(attrInt -> ColumnStat(distinctCount = 5, min = Some(6), max = Some(10),
nullCount = 0, avgLen = 4, maxLen = 4)),
expectedRowCount = 5)
}
@@ -279,7 +279,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
val condition = And(GreaterThan(attrInt, Literal(3)), LessThanOrEqual(attrInt, Literal(6)))
validateEstimatedStats(
Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)),
- Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(3), max = Some(6),
+ Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(3), max = Some(6),
nullCount = 0, avgLen = 4, maxLen = 4)),
expectedRowCount = 4)
}
@@ -288,8 +288,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
val condition = Or(EqualTo(attrInt, Literal(3)), EqualTo(attrInt, Literal(6)))
validateEstimatedStats(
Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)),
- Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4)),
+ Seq(attrInt -> colStatInt.copy(distinctCount = 2)),
expectedRowCount = 2)
}
@@ -297,7 +296,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
val condition = Not(And(GreaterThan(attrInt, Literal(3)), LessThanOrEqual(attrInt, Literal(6))))
validateEstimatedStats(
Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)),
- Seq(attrInt -> colStatInt),
+ Seq(attrInt -> colStatInt.copy(distinctCount = 6)),
expectedRowCount = 6)
}
@@ -305,7 +304,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
val condition = Not(Or(LessThanOrEqual(attrInt, Literal(3)), GreaterThan(attrInt, Literal(6))))
validateEstimatedStats(
Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)),
- Seq(attrInt -> colStatInt),
+ Seq(attrInt -> colStatInt.copy(distinctCount = 5)),
expectedRowCount = 5)
}
@@ -321,7 +320,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
val condition = Not(Or(EqualTo(attrInt, Literal(3)), LessThan(attrString, Literal("A8"))))
validateEstimatedStats(
Filter(condition, childStatsTestPlan(Seq(attrInt, attrString), 10L)),
- Seq(attrInt -> colStatInt, attrString -> colStatString),
+ Seq(attrInt -> colStatInt.copy(distinctCount = 9),
+ attrString -> colStatString.copy(distinctCount = 9)),
expectedRowCount = 9)
}
@@ -336,8 +336,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
test("cint NOT IN (3, 4, 5)") {
validateEstimatedStats(
Filter(Not(InSet(attrInt, Set(3, 4, 5))), childStatsTestPlan(Seq(attrInt), 10L)),
- Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4)),
+ Seq(attrInt -> colStatInt.copy(distinctCount = 7)),
expectedRowCount = 7)
}
@@ -380,7 +379,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
validateEstimatedStats(
Filter(LessThan(attrDate, Literal(d20170103, DateType)),
childStatsTestPlan(Seq(attrDate), 10L)),
- Seq(attrDate -> ColumnStat(distinctCount = 2, min = Some(dMin), max = Some(d20170103),
+ Seq(attrDate -> ColumnStat(distinctCount = 3, min = Some(dMin), max = Some(d20170103),
nullCount = 0, avgLen = 4, maxLen = 4)),
expectedRowCount = 3)
}
@@ -421,7 +420,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
test("cdouble < 3.0") {
validateEstimatedStats(
Filter(LessThan(attrDouble, Literal(3.0)), childStatsTestPlan(Seq(attrDouble), 10L)),
- Seq(attrDouble -> ColumnStat(distinctCount = 2, min = Some(1.0), max = Some(3.0),
+ Seq(attrDouble -> ColumnStat(distinctCount = 3, min = Some(1.0), max = Some(3.0),
nullCount = 0, avgLen = 8, maxLen = 8)),
expectedRowCount = 3)
}
@@ -487,9 +486,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
// partial overlap case
validateEstimatedStats(
Filter(EqualTo(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)),
- Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(7), max = Some(10),
+ Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(10),
nullCount = 0, avgLen = 4, maxLen = 4),
- attrInt2 -> ColumnStat(distinctCount = 3, min = Some(7), max = Some(10),
+ attrInt2 -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(10),
nullCount = 0, avgLen = 4, maxLen = 4)),
expectedRowCount = 4)
}
@@ -498,9 +497,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
// partial overlap case
validateEstimatedStats(
Filter(GreaterThan(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)),
- Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(7), max = Some(10),
+ Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(10),
nullCount = 0, avgLen = 4, maxLen = 4),
- attrInt2 -> ColumnStat(distinctCount = 3, min = Some(7), max = Some(10),
+ attrInt2 -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(10),
nullCount = 0, avgLen = 4, maxLen = 4)),
expectedRowCount = 4)
}
@@ -509,9 +508,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
// partial overlap case
validateEstimatedStats(
Filter(LessThan(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)),
- Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(10),
+ Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(10),
nullCount = 0, avgLen = 4, maxLen = 4),
- attrInt2 -> ColumnStat(distinctCount = 3, min = Some(7), max = Some(16),
+ attrInt2 -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(16),
nullCount = 0, avgLen = 4, maxLen = 4)),
expectedRowCount = 4)
}
@@ -531,9 +530,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
// partial overlap case
validateEstimatedStats(
Filter(LessThan(attrInt, attrInt4), childStatsTestPlan(Seq(attrInt, attrInt4), 10L)),
- Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(10),
+ Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(10),
nullCount = 0, avgLen = 4, maxLen = 4),
- attrInt4 -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(10),
+ attrInt4 -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(10),
nullCount = 0, avgLen = 4, maxLen = 4)),
expectedRowCount = 4)
}
@@ -565,6 +564,20 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
expectedRowCount = 0)
}
+ test("update ndv for columns based on overall selectivity") {
+ // filter condition: cint > 3 AND cint4 <= 6
+ val condition = And(GreaterThan(attrInt, Literal(3)), LessThanOrEqual(attrInt4, Literal(6)))
+ validateEstimatedStats(
+ Filter(condition, childStatsTestPlan(Seq(attrInt, attrInt4, attrString), 10L)),
+ Seq(
+ attrInt -> ColumnStat(distinctCount = 5, min = Some(3), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ attrInt4 -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(6),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ attrString -> colStatString.copy(distinctCount = 5)),
+ expectedRowCount = 5)
+ }
+
private def childStatsTestPlan(outList: Seq[Attribute], tableRowCount: BigInt): StatsTestPlan = {
StatsTestPlan(
outputList = outList,
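Editor's note: the "update ndv for columns based on overall selectivity" test encodes the idea that, after estimating the filtered row count, every output column's distinct count is scaled by the overall selectivity, not only the columns referenced by the predicate. A sketch of the arithmetic behind the expected values (not the estimator's exact code path):

val inputRows = 10.0
val outputRows = 5.0
val selectivity = outputRows / inputRows              // 0.5
val scaledNdv = math.ceil(10 * selectivity).toLong    // each column: ndv 10 -> 5
assert(scaledNdv == 5L)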
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index 37e3dfabd0b21..06ef7bcee0d84 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -54,13 +54,21 @@ case class ComplexPlan(exprs: Seq[Seq[Expression]])
override def output: Seq[Attribute] = Nil
}
-case class ExpressionInMap(map: Map[String, Expression]) extends Expression with Unevaluable {
+case class ExpressionInMap(map: Map[String, Expression]) extends Unevaluable {
override def children: Seq[Expression] = map.values.toSeq
override def nullable: Boolean = true
override def dataType: NullType = NullType
override lazy val resolved = true
}
+case class SeqTupleExpression(sons: Seq[(Expression, Expression)],
+ nonSons: Seq[(Expression, Expression)]) extends Unevaluable {
+ override def children: Seq[Expression] = sons.flatMap(t => Iterator(t._1, t._2))
+ override def nullable: Boolean = true
+ override def dataType: NullType = NullType
+ override lazy val resolved = true
+}
+
case class JsonTestTreeNode(arg: Any) extends LeafNode {
override def output: Seq[Attribute] = Seq.empty[Attribute]
}
@@ -146,6 +154,17 @@ class TreeNodeSuite extends SparkFunSuite {
assert(actual === Dummy(None))
}
+ test("mapChildren should only works on children") {
+ val children = Seq((Literal(1), Literal(2)))
+ val nonChildren = Seq((Literal(3), Literal(4)))
+ val before = SeqTupleExpression(children, nonChildren)
+ val toZero: PartialFunction[Expression, Expression] = { case Literal(_, _) => Literal(0) }
+ val expect = SeqTupleExpression(Seq((Literal(0), Literal(0))), nonChildren)
+
+ val actual = before mapChildren toZero
+ assert(actual === expect)
+ }
+
test("preserves origin") {
CurrentOrigin.setPosition(1, 1)
val add = Add(Literal(1), Literal(1))
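Editor's note: the new TreeNode test checks that `mapChildren` rewrites only the expressions exposed through `children`. A sketch using the `SeqTupleExpression` helper defined in the test above (so it is not fully self-contained outside that file):

import org.apache.spark.sql.catalyst.expressions.Literal

val before = SeqTupleExpression(
  sons = Seq((Literal(1), Literal(2))),
  nonSons = Seq((Literal(3), Literal(4))))
val after = before.mapChildren(_ => Literal(0))
// Only the `sons` tuples, which back `children`, are rewritten; `nonSons` stays untouched.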
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala
index 9799817494f15..deaf2f9d2bc21 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala
@@ -34,6 +34,22 @@ class DateTimeUtilsSuite extends SparkFunSuite {
((timestamp + tz.getOffset(timestamp)) / MILLIS_PER_DAY).toInt
}
+ test("nanoseconds truncation") {
+ def checkStringToTimestamp(originalTime: String, expectedParsedTime: String) {
+ val parsedTimestampOp = DateTimeUtils.stringToTimestamp(UTF8String.fromString(originalTime))
+ assert(parsedTimestampOp.isDefined, "timestamp with nanoseconds was not parsed correctly")
+ assert(DateTimeUtils.timestampToString(parsedTimestampOp.get) === expectedParsedTime)
+ }
+
+ checkStringToTimestamp("2015-01-02 00:00:00.123456789", "2015-01-02 00:00:00.123456")
+ checkStringToTimestamp("2015-01-02 00:00:00.100000009", "2015-01-02 00:00:00.1")
+ checkStringToTimestamp("2015-01-02 00:00:00.000050000", "2015-01-02 00:00:00.00005")
+ checkStringToTimestamp("2015-01-02 00:00:00.12005", "2015-01-02 00:00:00.12005")
+ checkStringToTimestamp("2015-01-02 00:00:00.100", "2015-01-02 00:00:00.1")
+ checkStringToTimestamp("2015-01-02 00:00:00.000456789", "2015-01-02 00:00:00.000456")
+ checkStringToTimestamp("1950-01-02 00:00:00.000456789", "1950-01-02 00:00:00.000456")
+ }
+
test("timestamp and us") {
val now = new Timestamp(System.currentTimeMillis())
now.setNanos(1000)
@@ -564,18 +580,18 @@ class DateTimeUtilsSuite extends SparkFunSuite {
assert(daysToMillis(16800, TimeZoneGMT) === c.getTimeInMillis)
// There are some days that are skipped entirely in some timezones; skip them here.
- val skipped_days = Map[String, Int](
- "Kwajalein" -> 8632,
- "Pacific/Apia" -> 15338,
- "Pacific/Enderbury" -> 9131,
- "Pacific/Fakaofo" -> 15338,
- "Pacific/Kiritimati" -> 9131,
- "Pacific/Kwajalein" -> 8632,
- "MIT" -> 15338)
+ val skipped_days = Map[String, Set[Int]](
+ "Kwajalein" -> Set(8632),
+ "Pacific/Apia" -> Set(15338),
+ "Pacific/Enderbury" -> Set(9130, 9131),
+ "Pacific/Fakaofo" -> Set(15338),
+ "Pacific/Kiritimati" -> Set(9130, 9131),
+ "Pacific/Kwajalein" -> Set(8632),
+ "MIT" -> Set(15338))
for (tz <- DateTimeTestUtils.ALL_TIMEZONES) {
- val skipped = skipped_days.getOrElse(tz.getID, Int.MinValue)
+ val skipped = skipped_days.getOrElse(tz.getID, Set.empty)
(-20000 to 20000).foreach { d =>
- if (d != skipped) {
+ if (!skipped.contains(d)) {
assert(millisToDays(daysToMillis(d, tz), tz) === d,
s"Round trip of ${d} did not work in tz ${tz}")
}
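Editor's note: timestamps are stored with microsecond precision, so the new test verifies that fractional digits beyond the sixth are truncated, not rounded, during parsing. A small sketch mirroring the helper above (uses the same utilities the test imports):

import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.unsafe.types.UTF8String

val parsed = DateTimeUtils.stringToTimestamp(
  UTF8String.fromString("2015-01-02 00:00:00.123456789"))
assert(parsed.isDefined)
// Nanosecond digits beyond microsecond precision are dropped.
assert(DateTimeUtils.timestampToString(parsed.get) == "2015-01-02 00:00:00.123456")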
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
index 714883a4099cf..f4cdb7058ab1e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
@@ -32,6 +32,16 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
test("creating decimals") {
checkDecimal(new Decimal(), "0", 1, 0)
+ checkDecimal(Decimal(BigDecimal("0.09")), "0.09", 3, 2)
+ checkDecimal(Decimal(BigDecimal("0.9")), "0.9", 2, 1)
+ checkDecimal(Decimal(BigDecimal("0.90")), "0.90", 3, 2)
+ checkDecimal(Decimal(BigDecimal("0.0")), "0.0", 2, 1)
+ checkDecimal(Decimal(BigDecimal("0")), "0", 1, 0)
+ checkDecimal(Decimal(BigDecimal("1.0")), "1.0", 2, 1)
+ checkDecimal(Decimal(BigDecimal("-0.09")), "-0.09", 3, 2)
+ checkDecimal(Decimal(BigDecimal("-0.9")), "-0.9", 2, 1)
+ checkDecimal(Decimal(BigDecimal("-0.90")), "-0.90", 3, 2)
+ checkDecimal(Decimal(BigDecimal("-1.0")), "-1.0", 2, 1)
checkDecimal(Decimal(BigDecimal("10.030")), "10.030", 5, 3)
checkDecimal(Decimal(BigDecimal("10.030"), 4, 1), "10.0", 4, 1)
checkDecimal(Decimal(BigDecimal("-9.95"), 4, 1), "-10.0", 4, 1)
@@ -203,7 +213,7 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
assert(d.changePrecision(10, 0, mode))
assert(d.toString === bd.setScale(0, mode).toString(), s"num: $sign$n, mode: $mode")
- val copy = d.toPrecision(10, 0, mode).orNull
+ val copy = d.toPrecision(10, 0, mode)
assert(copy !== null)
assert(d.ne(copy))
assert(d === copy)
@@ -212,4 +222,10 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
}
}
}
+
+ test("SPARK-20341: support BigInt's value does not fit in long value range") {
+ val bigInt = scala.math.BigInt("9223372036854775808")
+ val decimal = Decimal.apply(bigInt)
+ assert(decimal.toJavaBigDecimal.unscaledValue.toString === "9223372036854775808")
+ }
}
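Editor's note: the SPARK-20341 test covers BigInt values outside the Long range, which previously overflowed when building a Decimal. A short sketch of the behavior the test asserts:

import org.apache.spark.sql.types.Decimal

// Long.MaxValue is 9223372036854775807; one above it now round-trips through Decimal.
val d = Decimal(BigInt("9223372036854775808"))
assert(d.toJavaBigDecimal.unscaledValue.toString == "9223372036854775808")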
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index b203f31a76f03..32dab3d9e861d 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -22,7 +22,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.2.0-SNAPSHOT</version>
+ <version>2.2.3-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index cd521c52d1b21..4299cc8eccda2 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -20,7 +20,7 @@
import java.io.IOException;
import org.apache.spark.SparkEnv;
-import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.TaskContext;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
@@ -84,7 +84,7 @@ public static boolean supportsAggregationBufferSchema(StructType schema) {
* @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function)
* @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion.
* @param groupingKeySchema the schema of the grouping key, used for row conversion.
- * @param taskMemoryManager the memory manager used to allocate our Unsafe memory structures.
+ * @param taskContext the current task context.
* @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing).
* @param pageSizeBytes the data page size, in bytes; limits the maximum record size.
* @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact)
@@ -93,7 +93,7 @@ public UnsafeFixedWidthAggregationMap(
InternalRow emptyAggregationBuffer,
StructType aggregationBufferSchema,
StructType groupingKeySchema,
- TaskMemoryManager taskMemoryManager,
+ TaskContext taskContext,
int initialCapacity,
long pageSizeBytes,
boolean enablePerfMetrics) {
@@ -101,13 +101,20 @@ public UnsafeFixedWidthAggregationMap(
this.currentAggregationBuffer = new UnsafeRow(aggregationBufferSchema.length());
this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema);
this.groupingKeySchema = groupingKeySchema;
- this.map =
- new BytesToBytesMap(taskMemoryManager, initialCapacity, pageSizeBytes, enablePerfMetrics);
+ this.map = new BytesToBytesMap(
+ taskContext.taskMemoryManager(), initialCapacity, pageSizeBytes, enablePerfMetrics);
this.enablePerfMetrics = enablePerfMetrics;
// Initialize the buffer for aggregation value
final UnsafeProjection valueProjection = UnsafeProjection.create(aggregationBufferSchema);
this.emptyAggregationBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes();
+
+ // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at
+ // the end of the task. This is necessary to avoid memory leaks when the downstream operator
+ // does not fully consume the aggregation map's output (e.g. aggregate followed by limit).
+ taskContext.addTaskCompletionListener(context -> {
+ free();
+ });
}
/**
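Editor's note: passing the TaskContext (instead of just the TaskMemoryManager) lets the map free itself even when its consumer stops early, e.g. an aggregate followed by a limit. A Scala-flavored sketch of the same pattern with a hypothetical class name (the real class above is Java):

import org.apache.spark.TaskContext

// Tie cleanup of task-scoped memory to task completion so a partially consumed
// iterator cannot leak the backing buffers.
class LeakFreeBuffer(taskContext: TaskContext) {
  private def free(): Unit = {
    // release the unsafe/off-heap memory backing this buffer (placeholder)
  }
  taskContext.addTaskCompletionListener { _ => free() }
}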
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index ee5bcfd02c79e..7d67b87ed915d 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -33,6 +33,7 @@
import org.apache.spark.storage.BlockManager;
import org.apache.spark.unsafe.KVIterator;
import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.unsafe.map.BytesToBytesMap;
import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.util.collection.unsafe.sort.*;
@@ -96,15 +97,29 @@ public UnsafeKVExternalSorter(
numElementsForSpillThreshold,
canUseRadixSort);
} else {
- // The array will be used to do in-place sort, which require half of the space to be empty.
- // Note: each record in the map takes two entries in the array, one is record pointer,
- // another is the key prefix.
- assert(map.numKeys() * 2 <= map.getArray().size() / 2);
- // During spilling, the array in map will not be used, so we can borrow that and use it
- // as the underlying array for in-memory sorter (it's always large enough).
- // Since we will not grow the array, it's fine to pass `null` as consumer.
+ // During spilling, the pointer array in `BytesToBytesMap` will not be used, so we can borrow
+ // that and use it as the pointer array for `UnsafeInMemorySorter`.
+ LongArray pointerArray = map.getArray();
+ // `BytesToBytesMap`'s pointer array is only guaranteed to hold all the distinct keys, but
+ // `UnsafeInMemorySorter`'s pointer array needs to hold all the entries. Since
+ // `BytesToBytesMap` can have duplicate keys, we need a check here to make sure the pointer
+ // array can hold all the entries in `BytesToBytesMap`.
+ // The pointer array will be used to do in-place sort, which requires half of the space to be
+ // empty. Note: each record in the map takes two entries in the pointer array, one is record
+ // pointer, another is key prefix. So the required size of pointer array is `numRecords * 4`.
+ // TODO: It's possible to change UnsafeInMemorySorter to have multiple entries with same key,
+ // so that we can always reuse the pointer array.
+ if (map.numValues() > pointerArray.size() / 4) {
+ // Here we ask the map to allocate memory, so that the memory manager won't ask the map
+ // to spill if there is not enough memory.
+ pointerArray = map.allocateArray(map.numValues() * 4L);
+ }
+
+ // Since the pointer array (either reused from the map or newly allocated) is guaranteed
+ // to be large enough, it's fine to pass `null` as consumer because we won't allocate more
+ // memory.
final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter(
- null, taskMemoryManager, recordComparator, prefixComparator, map.getArray(),
+ null, taskMemoryManager, recordComparator, prefixComparator, pointerArray,
canUseRadixSort);
// We cannot use the destructive iterator here because we are reusing the existing memory
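
The sizing rule in the comments above reduces to simple arithmetic: each record contributes two long entries (record pointer plus key prefix), and the in-place sort needs the array to be at most half full, so an array of `n` entries can sort at most `n / 4` records. A tiny, illustrative sketch of that check:

    // Returns true if a pointer array of `pointerArraySize` long entries can be
    // reused for sorting `numRecords` records: 2 entries per record, and half of
    // the array must stay empty for the in-place sort.
    def canReusePointerArray(numRecords: Long, pointerArraySize: Long): Boolean =
      numRecords * 2 <= pointerArraySize / 2

    assert(canReusePointerArray(numRecords = 16, pointerArraySize = 64))   // exactly full
    assert(!canReusePointerArray(numRecords = 17, pointerArraySize = 64))  // must grow: needs 17 * 4 = 68 entries
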
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java
index eb97118872ea1..5a810cae1e184 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java
@@ -66,7 +66,6 @@
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.StructType$;
import org.apache.spark.util.AccumulatorV2;
-import org.apache.spark.util.LongAccumulator;
/**
* Base class for custom RecordReaders for Parquet that directly materialize to `T`.
@@ -153,14 +152,16 @@ public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptCont
}
// For test purpose.
- // If the predefined accumulator exists, the row group number to read will be updated
- // to the accumulator. So we can check if the row groups are filtered or not in test case.
+ // If the last external accumulator is `NumRowGroupsAccumulator`, the row group number to read
+ // will be updated to the accumulator. So we can check if the row groups are filtered or not
+ // in test case.
TaskContext taskContext = TaskContext$.MODULE$.get();
if (taskContext != null) {
- Option<AccumulatorV2<?, ?>> accu = taskContext.taskMetrics()
- .lookForAccumulatorByName("numRowGroups");
- if (accu.isDefined()) {
- ((LongAccumulator)accu.get()).add((long)blocks.size());
+ Option<AccumulatorV2<?, ?>> accu = taskContext.taskMetrics().externalAccums().lastOption();
+ if (accu.isDefined() && accu.get().getClass().getSimpleName().equals("NumRowGroupsAcc")) {
+ @SuppressWarnings("unchecked")
+ AccumulatorV2<Integer, Integer> intAccum = (AccumulatorV2<Integer, Integer>) accu.get();
+ intAccum.add(blocks.size());
}
}
}
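
For context, the `NumRowGroupsAcc` that the new check matches by simple name is a test-side accumulator; the class itself lives in the test sources, so its exact shape here is an assumption. A hedged sketch of what such an integer-summing accumulator could look like, registered by a test via `SparkContext.register` before the scan runs:

    import org.apache.spark.util.AccumulatorV2

    // Illustrative accumulator matching the simple-name check above.
    class NumRowGroupsAcc extends AccumulatorV2[Integer, Integer] {
      private var count = 0
      override def isZero: Boolean = count == 0
      override def copy(): AccumulatorV2[Integer, Integer] = {
        val acc = new NumRowGroupsAcc
        acc.count = count
        acc
      }
      override def reset(): Unit = count = 0
      override def add(v: Integer): Unit = count += v
      override def merge(other: AccumulatorV2[Integer, Integer]): Unit = count += other.value
      override def value: Integer = count
    }
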
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
index 354c878aca000..25524f370bb0d 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
@@ -180,7 +180,7 @@ public Object[] array() {
@Override
public boolean getBoolean(int ordinal) {
- throw new UnsupportedOperationException();
+ return data.getBoolean(offset + ordinal);
}
@Override
@@ -188,7 +188,7 @@ public boolean getBoolean(int ordinal) {
@Override
public short getShort(int ordinal) {
- throw new UnsupportedOperationException();
+ return data.getShort(offset + ordinal);
}
@Override
@@ -199,7 +199,7 @@ public short getShort(int ordinal) {
@Override
public float getFloat(int ordinal) {
- throw new UnsupportedOperationException();
+ return data.getFloat(offset + ordinal);
}
@Override
@@ -282,7 +282,21 @@ public void reset() {
* Cleans up memory for this column. The column is not usable after this.
* TODO: this should probably have ref-counted semantics.
*/
- public abstract void close();
+ public void close() {
+ if (childColumns != null) {
+ for (int i = 0; i < childColumns.length; i++) {
+ if (childColumns[i] != null) {
+ childColumns[i].close();
+ childColumns[i] = null;
+ }
+ }
+ }
+ if (dictionaryIds != null) {
+ dictionaryIds.close();
+ dictionaryIds = null;
+ }
+ dictionary = null;
+ }
public void reserve(int requiredCapacity) {
if (requiredCapacity > capacity) {
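
The new default `close()` follows a recursive release-and-null-out pattern so that child columns and dictionary ids are freed exactly once and a second `close()` is a harmless no-op. A minimal, self-contained Scala sketch of the same pattern; the names are illustrative, not the actual ColumnVector API:

    trait SimpleCloseable { def close(): Unit }

    class ParentVector(
        private val children: Array[SimpleCloseable],
        private var dictionaryIds: SimpleCloseable) extends SimpleCloseable {
      override def close(): Unit = {
        if (children != null) {
          children.indices.foreach { i =>
            if (children(i) != null) {
              children(i).close()
              children(i) = null   // make a repeated close() a no-op
            }
          }
        }
        if (dictionaryIds != null) {
          dictionaryIds.close()
          dictionaryIds = null
        }
      }
    }
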
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
index a6ce4c2edc232..8b7b0e655b31d 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
@@ -198,21 +198,25 @@ public boolean anyNull() {
@Override
public Decimal getDecimal(int ordinal, int precision, int scale) {
+ if (columns[ordinal].isNullAt(rowId)) return null;
return columns[ordinal].getDecimal(rowId, precision, scale);
}
@Override
public UTF8String getUTF8String(int ordinal) {
+ if (columns[ordinal].isNullAt(rowId)) return null;
return columns[ordinal].getUTF8String(rowId);
}
@Override
public byte[] getBinary(int ordinal) {
+ if (columns[ordinal].isNullAt(rowId)) return null;
return columns[ordinal].getBinary(rowId);
}
@Override
public CalendarInterval getInterval(int ordinal) {
+ if (columns[ordinal].isNullAt(rowId)) return null;
final int months = columns[ordinal].getChildColumn(0).getInt(rowId);
final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId);
return new CalendarInterval(months, microseconds);
@@ -220,11 +224,13 @@ public CalendarInterval getInterval(int ordinal) {
@Override
public InternalRow getStruct(int ordinal, int numFields) {
+ if (columns[ordinal].isNullAt(rowId)) return null;
return columns[ordinal].getStruct(rowId);
}
@Override
public ArrayData getArray(int ordinal) {
+ if (columns[ordinal].isNullAt(rowId)) return null;
return columns[ordinal].getArray(rowId);
}
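
From the caller's side, the effect of these checks is that the object-typed getters now return null for null entries instead of blindly reading the underlying column data. A small hedged sketch; the column layout (column 0 being a nullable string) is assumed for illustration:

    import org.apache.spark.sql.catalyst.InternalRow

    // `row` stands for a row handed out by a ColumnarBatch iterator.
    def nameOrPlaceholder(row: InternalRow): String = {
      val s = row.getUTF8String(0) // now returns null when the entry is null
      if (s == null) "<null>" else s.toString
    }
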
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
index e988c0722bd72..53a9b30bee7ab 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
@@ -64,6 +64,7 @@ public long nullsNativeAddress() {
@Override
public void close() {
+ super.close();
Platform.freeMemory(nulls);
Platform.freeMemory(data);
Platform.freeMemory(lengthData);
@@ -436,28 +437,29 @@ public void loadBytes(ColumnVector.Array array) {
// Split out the slow path.
@Override
protected void reserveInternal(int newCapacity) {
+ int oldCapacity = (nulls == 0L) ? 0 : capacity;
if (this.resultArray != null) {
this.lengthData =
- Platform.reallocateMemory(lengthData, elementsAppended * 4, newCapacity * 4);
+ Platform.reallocateMemory(lengthData, oldCapacity * 4, newCapacity * 4);
this.offsetData =
- Platform.reallocateMemory(offsetData, elementsAppended * 4, newCapacity * 4);
+ Platform.reallocateMemory(offsetData, oldCapacity * 4, newCapacity * 4);
} else if (type instanceof ByteType || type instanceof BooleanType) {
- this.data = Platform.reallocateMemory(data, elementsAppended, newCapacity);
+ this.data = Platform.reallocateMemory(data, oldCapacity, newCapacity);
} else if (type instanceof ShortType) {
- this.data = Platform.reallocateMemory(data, elementsAppended * 2, newCapacity * 2);
+ this.data = Platform.reallocateMemory(data, oldCapacity * 2, newCapacity * 2);
} else if (type instanceof IntegerType || type instanceof FloatType ||
type instanceof DateType || DecimalType.is32BitDecimalType(type)) {
- this.data = Platform.reallocateMemory(data, elementsAppended * 4, newCapacity * 4);
+ this.data = Platform.reallocateMemory(data, oldCapacity * 4, newCapacity * 4);
} else if (type instanceof LongType || type instanceof DoubleType ||
DecimalType.is64BitDecimalType(type) || type instanceof TimestampType) {
- this.data = Platform.reallocateMemory(data, elementsAppended * 8, newCapacity * 8);
+ this.data = Platform.reallocateMemory(data, oldCapacity * 8, newCapacity * 8);
} else if (resultStruct != null) {
// Nothing to store.
} else {
throw new RuntimeException("Unhandled " + type);
}
- this.nulls = Platform.reallocateMemory(nulls, elementsAppended, newCapacity);
- Platform.setMemory(nulls + elementsAppended, (byte)0, newCapacity - elementsAppended);
+ this.nulls = Platform.reallocateMemory(nulls, oldCapacity, newCapacity);
+ Platform.setMemory(nulls + oldCapacity, (byte)0, newCapacity - oldCapacity);
capacity = newCapacity;
}
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
index 9b410bacff5df..eb06c88d8e31c 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
@@ -68,6 +68,16 @@ public long nullsNativeAddress() {
@Override
public void close() {
+ super.close();
+ nulls = null;
+ byteData = null;
+ shortData = null;
+ intData = null;
+ longData = null;
+ floatData = null;
+ doubleData = null;
+ arrayLengths = null;
+ arrayOffsets = null;
}
//
@@ -410,53 +420,53 @@ protected void reserveInternal(int newCapacity) {
int[] newLengths = new int[newCapacity];
int[] newOffsets = new int[newCapacity];
if (this.arrayLengths != null) {
- System.arraycopy(this.arrayLengths, 0, newLengths, 0, elementsAppended);
- System.arraycopy(this.arrayOffsets, 0, newOffsets, 0, elementsAppended);
+ System.arraycopy(this.arrayLengths, 0, newLengths, 0, capacity);
+ System.arraycopy(this.arrayOffsets, 0, newOffsets, 0, capacity);
}
arrayLengths = newLengths;
arrayOffsets = newOffsets;
} else if (type instanceof BooleanType) {
if (byteData == null || byteData.length < newCapacity) {
byte[] newData = new byte[newCapacity];
- if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended);
+ if (byteData != null) System.arraycopy(byteData, 0, newData, 0, capacity);
byteData = newData;
}
} else if (type instanceof ByteType) {
if (byteData == null || byteData.length < newCapacity) {
byte[] newData = new byte[newCapacity];
- if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended);
+ if (byteData != null) System.arraycopy(byteData, 0, newData, 0, capacity);
byteData = newData;
}
} else if (type instanceof ShortType) {
if (shortData == null || shortData.length < newCapacity) {
short[] newData = new short[newCapacity];
- if (shortData != null) System.arraycopy(shortData, 0, newData, 0, elementsAppended);
+ if (shortData != null) System.arraycopy(shortData, 0, newData, 0, capacity);
shortData = newData;
}
} else if (type instanceof IntegerType || type instanceof DateType ||
DecimalType.is32BitDecimalType(type)) {
if (intData == null || intData.length < newCapacity) {
int[] newData = new int[newCapacity];
- if (intData != null) System.arraycopy(intData, 0, newData, 0, elementsAppended);
+ if (intData != null) System.arraycopy(intData, 0, newData, 0, capacity);
intData = newData;
}
} else if (type instanceof LongType || type instanceof TimestampType ||
DecimalType.is64BitDecimalType(type)) {
if (longData == null || longData.length < newCapacity) {
long[] newData = new long[newCapacity];
- if (longData != null) System.arraycopy(longData, 0, newData, 0, elementsAppended);
+ if (longData != null) System.arraycopy(longData, 0, newData, 0, capacity);
longData = newData;
}
} else if (type instanceof FloatType) {
if (floatData == null || floatData.length < newCapacity) {
float[] newData = new float[newCapacity];
- if (floatData != null) System.arraycopy(floatData, 0, newData, 0, elementsAppended);
+ if (floatData != null) System.arraycopy(floatData, 0, newData, 0, capacity);
floatData = newData;
}
} else if (type instanceof DoubleType) {
if (doubleData == null || doubleData.length < newCapacity) {
double[] newData = new double[newCapacity];
- if (doubleData != null) System.arraycopy(doubleData, 0, newData, 0, elementsAppended);
+ if (doubleData != null) System.arraycopy(doubleData, 0, newData, 0, capacity);
doubleData = newData;
}
} else if (resultStruct != null) {
@@ -466,7 +476,7 @@ protected void reserveInternal(int newCapacity) {
}
byte[] newNulls = new byte[newCapacity];
- if (nulls != null) System.arraycopy(nulls, 0, newNulls, 0, elementsAppended);
+ if (nulls != null) System.arraycopy(nulls, 0, newNulls, 0, capacity);
nulls = newNulls;
capacity = newCapacity;
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/Trigger.java b/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java
similarity index 94%
rename from sql/core/src/main/scala/org/apache/spark/sql/streaming/Trigger.java
rename to sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java
index 3e3997fa9bfec..d31790a285687 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/Trigger.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java
@@ -21,22 +21,18 @@
import scala.concurrent.duration.Duration;
-import org.apache.spark.annotation.Experimental;
import org.apache.spark.annotation.InterfaceStability;
import org.apache.spark.sql.execution.streaming.OneTimeTrigger$;
/**
- * :: Experimental ::
* Policy used to indicate how often results should be produced by a [[StreamingQuery]].
*
* @since 2.0.0
*/
-@Experimental
@InterfaceStability.Evolving
public class Trigger {
/**
- * :: Experimental ::
* A trigger policy that runs a query periodically based on an interval in processing time.
* If `interval` is 0, the query will run as fast as possible.
*
@@ -47,7 +43,6 @@ public static Trigger ProcessingTime(long intervalMs) {
}
/**
- * :: Experimental ::
* (Java-friendly)
* A trigger policy that runs a query periodically based on an interval in processing time.
* If `interval` is 0, the query will run as fast as possible.
@@ -64,7 +59,6 @@ public static Trigger ProcessingTime(long interval, TimeUnit timeUnit) {
}
/**
- * :: Experimental ::
* (Scala-friendly)
* A trigger policy that runs a query periodically based on an interval in processing time.
* If `duration` is 0, the query will run as fast as possible.
@@ -80,7 +74,6 @@ public static Trigger ProcessingTime(Duration interval) {
}
/**
- * :: Experimental ::
* A trigger policy that runs a query periodically based on an interval in processing time.
* If `interval` is effectively 0, the query will run as fast as possible.
*
diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
index 27d32b5dca431..0c5f3f22e31e8 100644
--- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
+++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
@@ -5,3 +5,4 @@ org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
org.apache.spark.sql.execution.datasources.text.TextFileFormat
org.apache.spark.sql.execution.streaming.ConsoleSinkProvider
org.apache.spark.sql.execution.streaming.TextSocketSourceProvider
+org.apache.spark.sql.execution.streaming.RateSourceProvider
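
Registering `RateSourceProvider` in this service file is what makes the source resolvable by its short name from `readStream`. A hedged usage sketch; treat `rowsPerSecond` as the assumed main option:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").appName("rate-demo").getOrCreate()

    // "rate" resolves through the DataSourceRegister service entry added above.
    val ticks = spark.readStream
      .format("rate")
      .option("rowsPerSecond", "10")
      .load()

    ticks.printSchema() // expected: a timestamp column and a monotonically increasing value column
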
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index 052d85ad33bd6..1d88992c48562 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -244,13 +244,13 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
* import com.google.common.collect.ImmutableMap;
*
* // Replaces all occurrences of 1.0 with 2.0 in column "height".
- * df.replace("height", ImmutableMap.of(1.0, 2.0));
+ * df.na.replace("height", ImmutableMap.of(1.0, 2.0));
*
* // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "name".
- * df.replace("name", ImmutableMap.of("UNKNOWN", "unnamed"));
+ * df.na.replace("name", ImmutableMap.of("UNKNOWN", "unnamed"));
*
* // Replaces all occurrences of "UNKNOWN" with "unnamed" in all string columns.
- * df.replace("*", ImmutableMap.of("UNKNOWN", "unnamed"));
+ * df.na.replace("*", ImmutableMap.of("UNKNOWN", "unnamed"));
* }}}
*
* @param col name of the column to apply the value replacement
@@ -271,10 +271,10 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
* import com.google.common.collect.ImmutableMap;
*
* // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight".
- * df.replace(new String[] {"height", "weight"}, ImmutableMap.of(1.0, 2.0));
+ * df.na.replace(new String[] {"height", "weight"}, ImmutableMap.of(1.0, 2.0));
*
* // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "firstname" and "lastname".
- * df.replace(new String[] {"firstname", "lastname"}, ImmutableMap.of("UNKNOWN", "unnamed"));
+ * df.na.replace(new String[] {"firstname", "lastname"}, ImmutableMap.of("UNKNOWN", "unnamed"));
* }}}
*
* @param cols list of columns to apply the value replacement
@@ -295,13 +295,13 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
*
* {{{
* // Replaces all occurrences of 1.0 with 2.0 in column "height".
- * df.replace("height", Map(1.0 -> 2.0))
+ * df.na.replace("height", Map(1.0 -> 2.0));
*
* // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "name".
- * df.replace("name", Map("UNKNOWN" -> "unnamed")
+ * df.na.replace("name", Map("UNKNOWN" -> "unnamed"));
*
* // Replaces all occurrences of "UNKNOWN" with "unnamed" in all string columns.
- * df.replace("*", Map("UNKNOWN" -> "unnamed")
+ * df.na.replace("*", Map("UNKNOWN" -> "unnamed"));
* }}}
*
* @param col name of the column to apply the value replacement
@@ -324,10 +324,10 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
*
* {{{
* // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight".
- * df.replace("height" :: "weight" :: Nil, Map(1.0 -> 2.0));
+ * df.na.replace("height" :: "weight" :: Nil, Map(1.0 -> 2.0));
*
* // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "firstname" and "lastname".
- * df.replace("firstname" :: "lastname" :: Nil, Map("UNKNOWN" -> "unnamed");
+ * df.na.replace("firstname" :: "lastname" :: Nil, Map("UNKNOWN" -> "unnamed"));
* }}}
*
* @param cols list of columns to apply the value replacement
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index c1b32917415ae..628a82fd23c13 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -283,7 +283,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* Loads JSON files and returns the results as a `DataFrame`.
*
* JSON Lines (newline-delimited JSON) is supported by
- * default. For JSON (one record per file), set the `wholeFile` option to true.
+ * default. For JSON (one record per file), set the `multiLine` option to true.
*
* This function goes through the input once to determine the input schema. If you know the
* schema in advance, use the version that specifies the schema to avoid the extra scan.
@@ -323,7 +323,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that
* indicates a timestamp format. Custom date formats follow the formats at
* `java.text.SimpleDateFormat`. This applies to timestamp type.
- * `wholeFile` (default `false`): parse one record, which may span multiple lines,
+ * `multiLine` (default `false`): parse one record, which may span multiple lines,
* per file
*
*
@@ -525,7 +525,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* `columnNameOfCorruptRecord` (default is the value specified in
* `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string
* created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.
- * `wholeFile` (default `false`): parse one record, which may span multiple lines.
+ * `multiLine` (default `false`): parse one record, which may span multiple lines.
*
* @since 2.0.0
*/
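
Since the option was renamed rather than merely re-documented, callers reading multi-line JSON should now spell it `multiLine`. A minimal sketch, with an illustrative file path and assuming an active SparkSession:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().getOrCreate()

    // One JSON record may span multiple lines per file.
    val df = spark.read
      .option("multiLine", "true")
      .json("/path/to/multi-line.json")
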
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 1732a8e08b73f..0259fffeab2db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -24,7 +24,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation}
-import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogRelation, CatalogTable, CatalogTableType}
+import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan}
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation, SaveIntoDataSourceCommand}
@@ -286,7 +286,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
partition = Map.empty[String, Option[String]],
query = df.logicalPlan,
overwrite = mode == SaveMode.Overwrite,
- ifNotExists = false)
+ ifPartitionNotExists = false)
}
}
@@ -372,8 +372,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
// Get all input data source or hive relations of the query.
val srcRelations = df.logicalPlan.collect {
case LogicalRelation(src: BaseRelation, _, _) => src
- case relation: CatalogRelation if DDLUtils.isHiveTable(relation.tableMeta) =>
- relation.tableMeta.identifier
+ case relation: HiveTableRelation => relation.tableMeta.identifier
}
val tableRelation = df.sparkSession.table(tableIdentWithDB).queryExecution.analyzed
@@ -383,8 +382,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
throw new AnalysisException(
s"Cannot overwrite table $tableName that is also being read from")
// check hive table relation when overwrite mode
- case relation: CatalogRelation if DDLUtils.isHiveTable(relation.tableMeta)
- && srcRelations.contains(relation.tableMeta.identifier) =>
+ case relation: HiveTableRelation
+ if srcRelations.contains(relation.tableMeta.identifier) =>
throw new AnalysisException(
s"Cannot overwrite table $tableName that is also being read from")
case _ => // OK
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 520663f624408..4c13daa386a8f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql
import java.io.CharArrayWriter
import java.sql.{Date, Timestamp}
-import java.util.TimeZone
import scala.collection.JavaConverters._
import scala.language.implicitConversions
@@ -36,10 +35,11 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.catalog.CatalogRelation
+import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection
import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions}
import org.apache.spark.sql.catalyst.optimizer.CombineUnions
import org.apache.spark.sql.catalyst.parser.ParseException
@@ -132,7 +132,7 @@ private[sql] object Dataset {
*
* people.filter("age > 30")
* .join(department, people("deptId") === department("id"))
- * .groupBy(department("name"), "gender")
+ * .groupBy(department("name"), people("gender"))
* .agg(avg(people("salary")), max(people("age")))
* }}}
*
@@ -142,9 +142,9 @@ private[sql] object Dataset {
* Dataset people = spark.read().parquet("...");
* Dataset department = spark.read().parquet("...");
*
- * people.filter("age".gt(30))
- * .join(department, people.col("deptId").equalTo(department("id")))
- * .groupBy(department.col("name"), "gender")
+ * people.filter(people.col("age").gt(30))
+ * .join(department, people.col("deptId").equalTo(department.col("id")))
+ * .groupBy(department.col("name"), people.col("gender"))
* .agg(avg(people.col("salary")), max(people.col("age")));
* }}}
*
@@ -196,15 +196,10 @@ class Dataset[T] private[sql](
*/
private[sql] implicit val exprEnc: ExpressionEncoder[T] = encoderFor(encoder)
- /**
- * Encoder is used mostly as a container of serde expressions in Dataset. We build logical
- * plans by these serde expressions and execute it within the query framework. However, for
- * performance reasons we may want to use encoder as a function to deserialize internal rows to
- * custom objects, e.g. collect. Here we resolve and bind the encoder so that we can call its
- * `fromRow` method later.
- */
- private val boundEnc =
- exprEnc.resolveAndBind(logicalPlan.output, sparkSession.sessionState.analyzer)
+ // The deserializer expression which can be used to build a projection and turn rows to objects
+ // of type T, after collecting rows to the driver side.
+ private val deserializer =
+ exprEnc.resolveAndBind(logicalPlan.output, sparkSession.sessionState.analyzer).deserializer
private implicit def classTag = exprEnc.clsTag
@@ -247,7 +242,8 @@ class Dataset[T] private[sql](
val hasMoreData = takeResult.length > numRows
val data = takeResult.take(numRows)
- lazy val timeZone = TimeZone.getTimeZone(sparkSession.sessionState.conf.sessionLocalTimeZone)
+ lazy val timeZone =
+ DateTimeUtils.getTimeZone(sparkSession.sessionState.conf.sessionLocalTimeZone)
// For array values, replace Seq and Array with square brackets
// For cells that are beyond `truncate` characters, replace it with the
@@ -372,6 +368,10 @@ class Dataset[T] private[sql](
* If the schema of the Dataset does not match the desired `U` type, you can use `select`
* along with `alias` or `as` to rearrange or rename as required.
*
+ * Note that `as[]` only changes the view of the data that is passed into typed operations,
+ * such as `map()`, and does not eagerly project away any columns that are not present in
+ * the specified class.
+ *
* @group basic
* @since 1.6.0
*/
@@ -484,7 +484,6 @@ class Dataset[T] private[sql](
* @group streaming
* @since 2.0.0
*/
- @Experimental
@InterfaceStability.Evolving
def isStreaming: Boolean = logicalPlan.isStreaming
@@ -545,7 +544,6 @@ class Dataset[T] private[sql](
}
/**
- * :: Experimental ::
* Defines an event time watermark for this [[Dataset]]. A watermark tracks a point in time
* before which we assume no more late data is going to arrive.
*
@@ -569,7 +567,6 @@ class Dataset[T] private[sql](
* @group streaming
* @since 2.1.0
*/
- @Experimental
@InterfaceStability.Evolving
// We only accept an existing column name, not a derived column here as a watermark that is
// defined on a derived column cannot referenced elsewhere in the plan.
@@ -579,7 +576,8 @@ class Dataset[T] private[sql](
.getOrElse(throw new AnalysisException(s"Unable to parse time delay '$delayThreshold'"))
require(parsedDelay.milliseconds >= 0 && parsedDelay.months >= 0,
s"delay threshold ($delayThreshold) should not be negative.")
- EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, logicalPlan)
+ EliminateEventTimeWatermark(
+ EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, logicalPlan))
}
/**
@@ -903,7 +901,7 @@ class Dataset[T] private[sql](
* @param condition Join expression.
* @param joinType Type of join to perform. Default `inner`. Must be one of:
* `inner`, `cross`, `outer`, `full`, `full_outer`, `left`, `left_outer`,
- * `right`, `right_outer`, `left_semi`, `left_anti`.
+ * `right`, `right_outer`.
*
* @group typedrel
* @since 1.6.0
@@ -920,6 +918,10 @@ class Dataset[T] private[sql](
JoinType(joinType),
Some(condition.expr))).analyzed.asInstanceOf[Join]
+ if (joined.joinType == LeftSemi || joined.joinType == LeftAnti) {
+ throw new AnalysisException("Invalid join type in joinWith: " + joined.joinType.sql)
+ }
+
// For both join side, combine all outputs into a single column and alias it with "_1" or "_2",
// to match the schema for the encoder of the join result.
// Note that we do this before joining them, to enable the join operator to return null for one
@@ -1026,7 +1028,7 @@ class Dataset[T] private[sql](
*/
@scala.annotation.varargs
def sort(sortCol: String, sortCols: String*): Dataset[T] = {
- sort((sortCol +: sortCols).map(apply) : _*)
+ sort((sortCol +: sortCols).map(Column(_)) : _*)
}
/**
@@ -1073,6 +1075,22 @@ class Dataset[T] private[sql](
*/
def apply(colName: String): Column = col(colName)
+ /**
+ * Specifies some hint on the current Dataset. As an example, the following code specifies
+ * that one of the plans can be broadcast:
+ *
+ * {{{
+ * df1.join(df2.hint("broadcast"))
+ * }}}
+ *
+ * @group basic
+ * @since 2.2.0
+ */
+ @scala.annotation.varargs
+ def hint(name: String, parameters: Any*): Dataset[T] = withTypedPlan {
+ UnresolvedHint(name, parameters, logicalPlan)
+ }
+
/**
* Selects column based on the column name and return it as a [[Column]].
*
@@ -1613,10 +1631,11 @@ class Dataset[T] private[sql](
/**
* Returns a new Dataset containing union of rows in this Dataset and another Dataset.
- * This is equivalent to `UNION ALL` in SQL.
*
- * To do a SQL-style set union (that does deduplication of elements), use this function followed
- * by a [[distinct]].
+ * This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union (that does
+ * deduplication of elements), use this function followed by a [[distinct]].
+ *
+ * Also as standard in SQL, this function resolves columns by position (not by name).
*
* @group typedrel
* @since 2.0.0
@@ -1626,10 +1645,11 @@ class Dataset[T] private[sql](
/**
* Returns a new Dataset containing union of rows in this Dataset and another Dataset.
- * This is equivalent to `UNION ALL` in SQL.
*
- * To do a SQL-style set union (that does deduplication of elements), use this function followed
- * by a [[distinct]].
+ * This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union (that does
+ * deduplication of elements), use this function followed by a [[distinct]].
+ *
+ * Also as standard in SQL, this function resolves columns by position (not by name).
*
* @group typedrel
* @since 2.0.0
@@ -1726,15 +1746,23 @@ class Dataset[T] private[sql](
// It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its
// constituent partitions each time a split is materialized which could result in
// overlapping splits. To prevent this, we explicitly sort each input partition to make the
- // ordering deterministic.
- // MapType cannot be sorted.
- val sorted = Sort(logicalPlan.output.filterNot(_.dataType.isInstanceOf[MapType])
- .map(SortOrder(_, Ascending)), global = false, logicalPlan)
+ // ordering deterministic. Note that MapTypes cannot be sorted and are explicitly pruned out
+ // from the sort order.
+ val sortOrder = logicalPlan.output
+ .filter(attr => RowOrdering.isOrderable(attr.dataType))
+ .map(SortOrder(_, Ascending))
+ val plan = if (sortOrder.nonEmpty) {
+ Sort(sortOrder, global = false, logicalPlan)
+ } else {
+ // SPARK-12662: If sort order is empty, we materialize the dataset to guarantee determinism
+ cache()
+ logicalPlan
+ }
val sum = weights.sum
val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
normalizedCumWeights.sliding(2).map { x =>
new Dataset[T](
- sparkSession, Sample(x(0), x(1), withReplacement = false, seed, sorted)(), encoder)
+ sparkSession, Sample(x(0), x(1), withReplacement = false, seed, plan)(), encoder)
}.toArray
}
@@ -2390,7 +2418,15 @@ class Dataset[T] private[sql](
*/
def toLocalIterator(): java.util.Iterator[T] = {
withAction("toLocalIterator", queryExecution) { plan =>
- plan.executeToIterator().map(boundEnc.fromRow).asJava
+ // This projection writes output to an `InternalRow`, which means applying this projection is
+ // not thread-safe. Here we create the projection inside this method to make `Dataset`
+ // thread-safe.
+ val objProj = GenerateSafeProjection.generate(deserializer :: Nil)
+ plan.executeToIterator().map { row =>
+ // The row returned by SafeProjection is `SpecificInternalRow`, which ignores the data type
+ // parameter of its `get` method, so it's safe to use null here.
+ objProj(row).get(0, null).asInstanceOf[T]
+ }.asJava
}
}
@@ -2632,6 +2668,22 @@ class Dataset[T] private[sql](
createTempViewCommand(viewName, replace = false, global = true)
}
+ /**
+ * Creates or replaces a global temporary view using the given name. The lifetime of this
+ * temporary view is tied to this Spark application.
+ *
+ * Global temporary view is cross-session. Its lifetime is the lifetime of the Spark application,
+ * i.e. it will be automatically dropped when the application terminates. It's tied to a system
+ * preserved database `_global_temp`, and we must use the qualified name to refer to a global temp
+ * view, e.g. `SELECT * FROM _global_temp.view1`.
+ *
+ * @group basic
+ * @since 2.2.0
+ */
+ def createOrReplaceGlobalTempView(viewName: String): Unit = withPlan {
+ createTempViewCommand(viewName, replace = true, global = true)
+ }
+
private def createTempViewCommand(
viewName: String,
replace: Boolean,
@@ -2670,13 +2722,11 @@ class Dataset[T] private[sql](
}
/**
- * :: Experimental ::
* Interface for saving the content of the streaming Dataset out into external storage.
*
* @group basic
* @since 2.0.0
*/
- @Experimental
@InterfaceStability.Evolving
def writeStream: DataStreamWriter[T] = {
if (!isStreaming) {
@@ -2735,7 +2785,7 @@ class Dataset[T] private[sql](
fsBasedRelation.inputFiles
case fr: FileRelation =>
fr.inputFiles
- case r: CatalogRelation if DDLUtils.isHiveTable(r.tableMeta) =>
+ case r: HiveTableRelation =>
r.tableMeta.storage.locationUri.map(_.toString).toArray
}.flatten
files.toSet.toArray
@@ -2754,7 +2804,7 @@ class Dataset[T] private[sql](
EvaluatePython.javaToPython(rdd)
}
- private[sql] def collectToPython(): Int = {
+ private[sql] def collectToPython(): Array[Any] = {
EvaluatePython.registerPicklers()
withNewExecutionId {
val toJava: (Any) => Any = EvaluatePython.toJava(_, schema)
@@ -2764,7 +2814,7 @@ class Dataset[T] private[sql](
}
}
- private[sql] def toPythonIterator(): Int = {
+ private[sql] def toPythonIterator(): Array[Any] = {
withNewExecutionId {
PythonRDD.toLocalIteratorAndServe(javaToPython.rdd)
}
@@ -2778,7 +2828,7 @@ class Dataset[T] private[sql](
* Wrap a Dataset action to track all Spark jobs in the body so that we can connect them with
* an execution.
*/
- private[sql] def withNewExecutionId[U](body: => U): U = {
+ private def withNewExecutionId[U](body: => U): U = {
SQLExecution.withNewExecutionId(sparkSession, queryExecution)(body)
}
@@ -2809,7 +2859,14 @@ class Dataset[T] private[sql](
* Collect all elements from a spark plan.
*/
private def collectFromPlan(plan: SparkPlan): Array[T] = {
- plan.executeCollect().map(boundEnc.fromRow)
+ // This projection writes output to an `InternalRow`, which means applying this projection is not
+ // thread-safe. Here we create the projection inside this method to make `Dataset` thread-safe.
+ val objProj = GenerateSafeProjection.generate(deserializer :: Nil)
+ plan.executeCollect().map { row =>
+ // The row returned by SafeProjection is `SpecificInternalRow`, which ignores the data type
+ // parameter of its `get` method, so it's safe to use null here.
+ objProj(row).get(0, null).asInstanceOf[T]
+ }
}
private def sortInternal(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = {
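
Taken together, the user-visible additions in this file are `hint` and `createOrReplaceGlobalTempView`. A hedged usage sketch; the data and names are illustrative:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    import spark.implicits._

    val small = Seq((1L, "a"), (2L, "b")).toDF("id", "label")
    val large = spark.range(0, 1000000).toDF("id")

    // Ask the planner to broadcast the smaller side of the join.
    val joined = large.join(small.hint("broadcast"), "id")

    // Publish the result to every session of this Spark application.
    joined.createOrReplaceGlobalTempView("joined_view")
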
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala
index 372ec262f5764..86e02e98c01f3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala
@@ -17,10 +17,9 @@
package org.apache.spark.sql
-import org.apache.spark.annotation.{Experimental, InterfaceStability}
+import org.apache.spark.annotation.InterfaceStability
/**
- * :: Experimental ::
* A class to consume data generated by a `StreamingQuery`. Typically this is used to send the
* generated data to external systems. Each partition will use a new deserialized instance, so you
* usually should do all the initialization (e.g. opening a connection or initiating a transaction)
@@ -66,7 +65,6 @@ import org.apache.spark.annotation.{Experimental, InterfaceStability}
* }}}
* @since 2.0.0
*/
-@Experimental
@InterfaceStability.Evolving
abstract class ForeachWriter[T] extends Serializable {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index cc2983987eb90..7fde6e9469e5e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -505,7 +505,6 @@ class SQLContext private[sql](val sparkSession: SparkSession)
/**
- * :: Experimental ::
* Returns a `DataStreamReader` that can be used to read streaming data in as a `DataFrame`.
* {{{
* sparkSession.readStream.parquet("/path/to/directory/of/parquet/files")
@@ -514,7 +513,6 @@ class SQLContext private[sql](val sparkSession: SparkSession)
*
* @since 2.0.0
*/
- @Experimental
@InterfaceStability.Evolving
def readStream: DataStreamReader = sparkSession.readStream
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index 375df64d39734..17671ea8685b9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -111,93 +111,60 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits {
/**
* @since 1.6.1
- * @deprecated use [[newIntSequenceEncoder]]
+ * @deprecated use [[newSequenceEncoder]]
*/
def newIntSeqEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
/**
* @since 1.6.1
- * @deprecated use [[newLongSequenceEncoder]]
+ * @deprecated use [[newSequenceEncoder]]
*/
def newLongSeqEncoder: Encoder[Seq[Long]] = ExpressionEncoder()
/**
* @since 1.6.1
- * @deprecated use [[newDoubleSequenceEncoder]]
+ * @deprecated use [[newSequenceEncoder]]
*/
def newDoubleSeqEncoder: Encoder[Seq[Double]] = ExpressionEncoder()
/**
* @since 1.6.1
- * @deprecated use [[newFloatSequenceEncoder]]
+ * @deprecated use [[newSequenceEncoder]]
*/
def newFloatSeqEncoder: Encoder[Seq[Float]] = ExpressionEncoder()
/**
* @since 1.6.1
- * @deprecated use [[newByteSequenceEncoder]]
+ * @deprecated use [[newSequenceEncoder]]
*/
def newByteSeqEncoder: Encoder[Seq[Byte]] = ExpressionEncoder()
/**
* @since 1.6.1
- * @deprecated use [[newShortSequenceEncoder]]
+ * @deprecated use [[newSequenceEncoder]]
*/
def newShortSeqEncoder: Encoder[Seq[Short]] = ExpressionEncoder()
/**
* @since 1.6.1
- * @deprecated use [[newBooleanSequenceEncoder]]
+ * @deprecated use [[newSequenceEncoder]]
*/
def newBooleanSeqEncoder: Encoder[Seq[Boolean]] = ExpressionEncoder()
/**
* @since 1.6.1
- * @deprecated use [[newStringSequenceEncoder]]
+ * @deprecated use [[newSequenceEncoder]]
*/
def newStringSeqEncoder: Encoder[Seq[String]] = ExpressionEncoder()
/**
* @since 1.6.1
- * @deprecated use [[newProductSequenceEncoder]]
+ * @deprecated use [[newSequenceEncoder]]
*/
- implicit def newProductSeqEncoder[A <: Product : TypeTag]: Encoder[Seq[A]] = ExpressionEncoder()
+ def newProductSeqEncoder[A <: Product : TypeTag]: Encoder[Seq[A]] = ExpressionEncoder()
/** @since 2.2.0 */
- implicit def newIntSequenceEncoder[T <: Seq[Int] : TypeTag]: Encoder[T] =
- ExpressionEncoder()
-
- /** @since 2.2.0 */
- implicit def newLongSequenceEncoder[T <: Seq[Long] : TypeTag]: Encoder[T] =
- ExpressionEncoder()
-
- /** @since 2.2.0 */
- implicit def newDoubleSequenceEncoder[T <: Seq[Double] : TypeTag]: Encoder[T] =
- ExpressionEncoder()
-
- /** @since 2.2.0 */
- implicit def newFloatSequenceEncoder[T <: Seq[Float] : TypeTag]: Encoder[T] =
- ExpressionEncoder()
-
- /** @since 2.2.0 */
- implicit def newByteSequenceEncoder[T <: Seq[Byte] : TypeTag]: Encoder[T] =
- ExpressionEncoder()
-
- /** @since 2.2.0 */
- implicit def newShortSequenceEncoder[T <: Seq[Short] : TypeTag]: Encoder[T] =
- ExpressionEncoder()
-
- /** @since 2.2.0 */
- implicit def newBooleanSequenceEncoder[T <: Seq[Boolean] : TypeTag]: Encoder[T] =
- ExpressionEncoder()
-
- /** @since 2.2.0 */
- implicit def newStringSequenceEncoder[T <: Seq[String] : TypeTag]: Encoder[T] =
- ExpressionEncoder()
-
- /** @since 2.2.0 */
- implicit def newProductSequenceEncoder[T <: Seq[Product] : TypeTag]: Encoder[T] =
- ExpressionEncoder()
+ implicit def newSequenceEncoder[T <: Seq[_] : TypeTag]: Encoder[T] = ExpressionEncoder()
// Arrays
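
The consolidation into a single `newSequenceEncoder` means any `Seq` whose element type has an encoder now works through one implicit. A small hedged sketch:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    import spark.implicits._

    // Encoder[Seq[Int]] is derived by the single newSequenceEncoder implicit.
    val ds = Seq(Seq(1, 2, 3), Seq(4, 5)).toDS()
    ds.map(_.sum).show() // 6 and 9
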
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 95f3463dfe62b..96882c62c2d67 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -32,13 +32,14 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
import org.apache.spark.sql.catalog.Catalog
import org.apache.spark.sql.catalyst._
+import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.ui.SQLListener
-import org.apache.spark.sql.internal.{BaseSessionStateBuilder, CatalogImpl, SessionState, SessionStateBuilder, SharedState}
+import org.apache.spark.sql.internal._
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.streaming._
@@ -77,11 +78,12 @@ import org.apache.spark.util.Utils
class SparkSession private(
@transient val sparkContext: SparkContext,
@transient private val existingSharedState: Option[SharedState],
- @transient private val parentSessionState: Option[SessionState])
+ @transient private val parentSessionState: Option[SessionState],
+ @transient private[sql] val extensions: SparkSessionExtensions)
extends Serializable with Closeable with Logging { self =>
private[sql] def this(sc: SparkContext) {
- this(sc, None, None)
+ this(sc, None, None, new SparkSessionExtensions)
}
sparkContext.assertNotStopped()
@@ -111,6 +113,12 @@ class SparkSession private(
existingSharedState.getOrElse(new SharedState(sparkContext))
}
+ /**
+ * Initial options for the session. These options are applied once when sessionState is created.
+ */
+ @transient
+ private[sql] val initialSessionOptions = new scala.collection.mutable.HashMap[String, String]
+
/**
* State isolated across sessions, including SQL configurations, temporary tables, registered
* functions, and everything else that accepts a [[org.apache.spark.sql.internal.SQLConf]].
@@ -126,9 +134,11 @@ class SparkSession private(
parentSessionState
.map(_.clone(this))
.getOrElse {
- SparkSession.instantiateSessionState(
+ val state = SparkSession.instantiateSessionState(
SparkSession.sessionStateClassName(sparkContext.conf),
self)
+ initialSessionOptions.foreach { case (k, v) => state.conf.setConfString(k, v) }
+ state
}
}
@@ -219,7 +229,7 @@ class SparkSession private(
* @since 2.0.0
*/
def newSession(): SparkSession = {
- new SparkSession(sparkContext, Some(sharedState), parentSessionState = None)
+ new SparkSession(sparkContext, Some(sharedState), parentSessionState = None, extensions)
}
/**
@@ -235,7 +245,7 @@ class SparkSession private(
* implementation is Hive, this will initialize the metastore, which may take some time.
*/
private[sql] def cloneSession(): SparkSession = {
- val result = new SparkSession(sparkContext, Some(sharedState), Some(sessionState))
+ val result = new SparkSession(sparkContext, Some(sharedState), Some(sessionState), extensions)
result.sessionState // force copy of SessionState
result
}
@@ -605,7 +615,7 @@ class SparkSession private(
}
private[sql] def table(tableIdent: TableIdentifier): DataFrame = {
- Dataset.ofRows(self, sessionState.catalog.lookupRelation(tableIdent))
+ Dataset.ofRows(self, UnresolvedRelation(tableIdent))
}
/* ----------------- *
@@ -635,7 +645,6 @@ class SparkSession private(
def read: DataFrameReader = new DataFrameReader(self)
/**
- * :: Experimental ::
* Returns a `DataStreamReader` that can be used to read streaming data in as a `DataFrame`.
* {{{
* sparkSession.readStream.parquet("/path/to/directory/of/parquet/files")
@@ -644,7 +653,6 @@ class SparkSession private(
*
* @since 2.0.0
*/
- @Experimental
@InterfaceStability.Evolving
def readStream: DataStreamReader = new DataStreamReader(self)
@@ -754,6 +762,8 @@ object SparkSession {
private[this] val options = new scala.collection.mutable.HashMap[String, String]
+ private[this] val extensions = new SparkSessionExtensions
+
private[this] var userSuppliedContext: Option[SparkContext] = None
private[spark] def sparkContext(sparkContext: SparkContext): Builder = synchronized {
@@ -847,6 +857,17 @@ object SparkSession {
}
}
+ /**
+ * Inject extensions into the [[SparkSession]]. This allows a user to add Analyzer rules,
+ * Optimizer rules, Planning Strategies or a customized parser.
+ *
+ * @since 2.2.0
+ */
+ def withExtensions(f: SparkSessionExtensions => Unit): Builder = {
+ f(extensions)
+ this
+ }
+
/**
* Gets an existing [[SparkSession]] or, if there is no existing one, creates a new
* one based on the options set in this builder.
@@ -903,8 +924,27 @@ object SparkSession {
}
sc
}
- session = new SparkSession(sparkContext)
- options.foreach { case (k, v) => session.sessionState.conf.setConfString(k, v) }
+
+ // Initialize extensions if the user has defined a configurator class.
+ val extensionConfOption = sparkContext.conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS)
+ if (extensionConfOption.isDefined) {
+ val extensionConfClassName = extensionConfOption.get
+ try {
+ val extensionConfClass = Utils.classForName(extensionConfClassName)
+ val extensionConf = extensionConfClass.newInstance()
+ .asInstanceOf[SparkSessionExtensions => Unit]
+ extensionConf(extensions)
+ } catch {
+ // Ignore the error if we cannot find the class or when the class has the wrong type.
+ case e @ (_: ClassCastException |
+ _: ClassNotFoundException |
+ _: NoClassDefFoundError) =>
+ logWarning(s"Cannot use $extensionConfClassName to configure session extensions.", e)
+ }
+ }
+
+ session = new SparkSession(sparkContext, None, None, extensions)
+ options.foreach { case (k, v) => session.initialSessionOptions.put(k, v) }
defaultSession.set(session)
// Register a successfully instantiated context to the singleton. This should be at the
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
new file mode 100644
index 0000000000000..f99c108161f94
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
@@ -0,0 +1,171 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.collection.mutable
+
+import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
+import org.apache.spark.sql.catalyst.parser.ParserInterface
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+
+/**
+ * :: Experimental ::
+ * Holder for injection points to the [[SparkSession]]. We make NO guarantee about the binary
+ * compatibility or source compatibility of the methods here.
+ *
+ * This currently provides the following extension points:
+ * - Analyzer Rules.
+ * - Check Analysis Rules.
+ * - Optimizer Rules.
+ * - Planning Strategies.
+ * - Customized Parser.
+ * - (External) Catalog listeners.
+ *
+ * The extensions can be used by calling withExtensions on the [[SparkSession.Builder]], for
+ * example:
+ * {{{
+ * SparkSession.builder()
+ * .master("...")
+ * .conf("...", true)
+ * .withExtensions { extensions =>
+ * extensions.injectResolutionRule { session =>
+ * ...
+ * }
+ * extensions.injectParser { (session, parser) =>
+ * ...
+ * }
+ * }
+ * .getOrCreate()
+ * }}}
+ *
+ * Note that none of the injected builders should assume that the [[SparkSession]] is fully
+ * initialized and should not touch the session's internals (e.g. the SessionState).
+ */
+@DeveloperApi
+@Experimental
+@InterfaceStability.Unstable
+class SparkSessionExtensions {
+ type RuleBuilder = SparkSession => Rule[LogicalPlan]
+ type CheckRuleBuilder = SparkSession => LogicalPlan => Unit
+ type StrategyBuilder = SparkSession => Strategy
+ type ParserBuilder = (SparkSession, ParserInterface) => ParserInterface
+
+ private[this] val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder]
+
+ /**
+ * Build the analyzer resolution `Rule`s using the given [[SparkSession]].
+ */
+ private[sql] def buildResolutionRules(session: SparkSession): Seq[Rule[LogicalPlan]] = {
+ resolutionRuleBuilders.map(_.apply(session))
+ }
+
+ /**
+ * Inject an analyzer resolution `Rule` builder into the [[SparkSession]]. These analyzer
+ * rules will be executed as part of the resolution phase of analysis.
+ */
+ def injectResolutionRule(builder: RuleBuilder): Unit = {
+ resolutionRuleBuilders += builder
+ }
+
+ private[this] val postHocResolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder]
+
+ /**
+ * Build the analyzer post-hoc resolution `Rule`s using the given [[SparkSession]].
+ */
+ private[sql] def buildPostHocResolutionRules(session: SparkSession): Seq[Rule[LogicalPlan]] = {
+ postHocResolutionRuleBuilders.map(_.apply(session))
+ }
+
+ /**
+ * Inject an analyzer `Rule` builder into the [[SparkSession]]. These analyzer
+ * rules will be executed after resolution.
+ */
+ def injectPostHocResolutionRule(builder: RuleBuilder): Unit = {
+ postHocResolutionRuleBuilders += builder
+ }
+
+ private[this] val checkRuleBuilders = mutable.Buffer.empty[CheckRuleBuilder]
+
+ /**
+ * Build the check analysis `Rule`s using the given [[SparkSession]].
+ */
+ private[sql] def buildCheckRules(session: SparkSession): Seq[LogicalPlan => Unit] = {
+ checkRuleBuilders.map(_.apply(session))
+ }
+
+ /**
+ * Inject a check analysis `Rule` builder into the [[SparkSession]]. The injected rules will
+ * be executed after the analysis phase. A check analysis rule is used to detect problems with a
+ * LogicalPlan and should throw an exception when a problem is found.
+ */
+ def injectCheckRule(builder: CheckRuleBuilder): Unit = {
+ checkRuleBuilders += builder
+ }
+
+ private[this] val optimizerRules = mutable.Buffer.empty[RuleBuilder]
+
+ private[sql] def buildOptimizerRules(session: SparkSession): Seq[Rule[LogicalPlan]] = {
+ optimizerRules.map(_.apply(session))
+ }
+
+ /**
+ * Inject an optimizer `Rule` builder into the [[SparkSession]]. The injected rules will be
+ * executed during the operator optimization batch. An optimizer rule is used to improve the
+ * quality of an analyzed logical plan; these rules should never modify the result of the
+ * LogicalPlan.
+ */
+ def injectOptimizerRule(builder: RuleBuilder): Unit = {
+ optimizerRules += builder
+ }
+
+ private[this] val plannerStrategyBuilders = mutable.Buffer.empty[StrategyBuilder]
+
+ private[sql] def buildPlannerStrategies(session: SparkSession): Seq[Strategy] = {
+ plannerStrategyBuilders.map(_.apply(session))
+ }
+
+ /**
+ * Inject a planner `Strategy` builder into the [[SparkSession]]. The injected strategy will
+ * be used to convert a `LogicalPlan` into an executable
+ * [[org.apache.spark.sql.execution.SparkPlan]].
+ */
+ def injectPlannerStrategy(builder: StrategyBuilder): Unit = {
+ plannerStrategyBuilders += builder
+ }
+
+ private[this] val parserBuilders = mutable.Buffer.empty[ParserBuilder]
+
+ private[sql] def buildParser(
+ session: SparkSession,
+ initial: ParserInterface): ParserInterface = {
+ parserBuilders.foldLeft(initial) { (parser, builder) =>
+ builder(session, parser)
+ }
+ }
+
+ /**
+ * Inject a custom parser into the [[SparkSession]]. Note that the builder is passed a session
+ * and an initial parser. The latter allows a user to create a partial parser and to delegate
+ * to the underlying parser for completeness. If a user injects more parsers, then the parsers
+ * are stacked on top of each other.
+ */
+ def injectParser(builder: ParserBuilder): Unit = {
+ parserBuilders += builder
+ }
+}
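
A hedged end-to-end sketch of the builder-side injection path defined above; the rule is a deliberate no-op and the names are illustrative:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
    import org.apache.spark.sql.catalyst.rules.Rule

    // A resolution rule that leaves the plan untouched, just to show the wiring.
    case class NoopResolutionRule(session: SparkSession) extends Rule[LogicalPlan] {
      override def apply(plan: LogicalPlan): LogicalPlan = plan
    }

    val spark = SparkSession.builder()
      .master("local[*]")
      .withExtensions { extensions =>
        extensions.injectResolutionRule(session => NoopResolutionRule(session))
      }
      .getOrCreate()

The same kind of configurator can alternatively be supplied through the static SQL conf consulted in the `getOrCreate` change above (`StaticSQLConf.SPARK_SESSION_EXTENSIONS`), provided it is a class implementing `SparkSessionExtensions => Unit`.
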
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index a57673334c10b..6accf1f75064c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -70,15 +70,31 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
* @param name the name of the UDAF.
* @param udaf the UDAF needs to be registered.
* @return the registered UDAF.
+ *
+ * @since 1.5.0
*/
- def register(
- name: String,
- udaf: UserDefinedAggregateFunction): UserDefinedAggregateFunction = {
+ def register(name: String, udaf: UserDefinedAggregateFunction): UserDefinedAggregateFunction = {
def builder(children: Seq[Expression]) = ScalaUDAF(children, udaf)
functionRegistry.registerFunction(name, builder)
udaf
}
+ /**
+ * Register a user-defined function (UDF), for a UDF that's already defined using the DataFrame
+ * API (i.e. of type UserDefinedFunction).
+ *
+ * @param name the name of the UDF.
+ * @param udf the UDF to be registered.
+ * @return the registered UDF.
+ *
+ * @since 2.2.0
+ */
+ def register(name: String, udf: UserDefinedFunction): UserDefinedFunction = {
+ def builder(children: Seq[Expression]) = udf.apply(children.map(Column.apply) : _*).expr
+ functionRegistry.registerFunction(name, builder)
+ udf
+ }
+
// scalastyle:off line.size.limit
/* register 0-22 were generated by this script
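
The new overload makes a DataFrame-style UDF usable from SQL text as well. A hedged usage sketch; the function and view names are illustrative:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.udf

    val spark = SparkSession.builder().master("local[*]").getOrCreate()

    // A UDF built with the DataFrame API...
    val plusOne = udf((x: Long) => x + 1)
    // ...registered under a SQL-visible name via the new overload.
    spark.udf.register("plus_one", plusOne)

    spark.range(3).createOrReplaceTempView("t")
    spark.sql("SELECT plus_one(id) FROM t").show() // 1, 2, 3
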
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index 866fa98533218..6fb41b6425c4b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -67,7 +67,7 @@ trait DataSourceScanExec extends LeafExecNode with CodegenSupport {
* Shorthand for calling redactString() without specifying redacting rules
*/
private def redact(text: String): String = {
- Utils.redact(SparkSession.getActiveSession.get.sparkContext.conf, text)
+ Utils.redact(SparkSession.getActiveSession.map(_.sparkContext.conf).orNull, text)
}
}
@@ -519,8 +519,8 @@ case class FileSourceScanExec(
relation,
output.map(QueryPlan.normalizeExprId(_, output)),
requiredSchema,
- partitionFilters.map(QueryPlan.normalizeExprId(_, output)),
- dataFilters.map(QueryPlan.normalizeExprId(_, output)),
+ QueryPlan.normalizePredicates(partitionFilters, output),
+ QueryPlan.normalizePredicates(dataFilters, output),
None)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala
index 458ac4ba3637c..01c9c65e5399d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala
@@ -31,16 +31,16 @@ import org.apache.spark.storage.BlockManager
import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, UnsafeSorterIterator}
/**
- * An append-only array for [[UnsafeRow]]s that spills content to disk when there a predefined
- * threshold of rows is reached.
+ * An append-only array for [[UnsafeRow]]s that keeps its content in an in-memory array until
+ * [[numRowsInMemoryBufferThreshold]] is reached, after which it switches to a mode that flushes
+ * to disk once [[numRowsSpillThreshold]] is met (or earlier, if there is excessive memory
+ * consumption). Setting these thresholds involves the following trade-offs:
*
- * Setting spill threshold faces following trade-off:
- *
- * - If the spill threshold is too high, the in-memory array may occupy more memory than is
- * available, resulting in OOM.
- * - If the spill threshold is too low, we spill frequently and incur unnecessary disk writes.
- * This may lead to a performance regression compared to the normal case of using an
- * [[ArrayBuffer]] or [[Array]].
+ * - If [[numRowsInMemoryBufferThreshold]] is too high, the in-memory array may occupy more memory
+ * than is available, resulting in OOM.
+ * - If [[numRowsSpillThreshold]] is too low, data will be spilled frequently and lead to
+ * excessive disk writes. This may lead to a performance regression compared to the normal case
+ * of using an [[ArrayBuffer]] or [[Array]].
*/
private[sql] class ExternalAppendOnlyUnsafeRowArray(
taskMemoryManager: TaskMemoryManager,
@@ -49,9 +49,10 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray(
taskContext: TaskContext,
initialSize: Int,
pageSizeBytes: Long,
+ numRowsInMemoryBufferThreshold: Int,
numRowsSpillThreshold: Int) extends Logging {
- def this(numRowsSpillThreshold: Int) {
+ def this(numRowsInMemoryBufferThreshold: Int, numRowsSpillThreshold: Int) {
this(
TaskContext.get().taskMemoryManager(),
SparkEnv.get.blockManager,
@@ -59,11 +60,12 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray(
TaskContext.get(),
1024,
SparkEnv.get.memoryManager.pageSizeBytes,
+ numRowsInMemoryBufferThreshold,
numRowsSpillThreshold)
}
private val initialSizeOfInMemoryBuffer =
- Math.min(DefaultInitialSizeOfInMemoryBuffer, numRowsSpillThreshold)
+ Math.min(DefaultInitialSizeOfInMemoryBuffer, numRowsInMemoryBufferThreshold)
private val inMemoryBuffer = if (initialSizeOfInMemoryBuffer > 0) {
new ArrayBuffer[UnsafeRow](initialSizeOfInMemoryBuffer)
@@ -102,11 +104,11 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray(
}
def add(unsafeRow: UnsafeRow): Unit = {
- if (numRows < numRowsSpillThreshold) {
+ if (numRows < numRowsInMemoryBufferThreshold) {
inMemoryBuffer += unsafeRow.copy()
} else {
if (spillableArray == null) {
- logInfo(s"Reached spill threshold of $numRowsSpillThreshold rows, switching to " +
+ logInfo(s"Reached spill threshold of $numRowsInMemoryBufferThreshold rows, switching to " +
s"${classOf[UnsafeExternalSorter].getName}")
// We will not sort the rows, so prefixComparator and recordComparator are null
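A rough sketch of how the two thresholds interact for this internal helper, assuming it is constructed inside a running task and that the class exposes a `generateIterator()` accessor (values are illustrative):

```scala
// Keep up to 4096 rows in the plain in-memory buffer, then hand rows to an
// UnsafeExternalSorter that spills after 16K more rows (or earlier under memory pressure).
val buffer = new ExternalAppendOnlyUnsafeRowArray(
  numRowsInMemoryBufferThreshold = 4096,
  numRowsSpillThreshold = 16 * 1024)

unsafeRows.foreach(buffer.add)      // unsafeRows: Iterator[UnsafeRow]
val it = buffer.generateIterator()  // replay the appended rows in insertion order
```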
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
index f87d05884b276..c35e5638e9273 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
private[execution] sealed case class LazyIterator(func: () => TraversableOnce[InternalRow])
extends Iterator[InternalRow] {
- lazy val results = func().toIterator
+ lazy val results: Iterator[InternalRow] = func().toIterator
override def hasNext: Boolean = results.hasNext
override def next(): InternalRow = results.next()
}
@@ -50,7 +50,7 @@ private[execution] sealed case class LazyIterator(func: () => TraversableOnce[In
* @param join when true, each output row is implicitly joined with the input tuple that produced
* it.
* @param outer when true, each input row will be output at least once, even if the output of the
- * given `generator` is empty. `outer` has no effect when `join` is false.
+ * given `generator` is empty.
* @param generatorOutput the qualified output attributes of the generator of this node, which
* constructed in analysis phase, and we can not change it, as the
* parent node bound with it already.
@@ -78,15 +78,15 @@ case class GenerateExec(
override def outputPartitioning: Partitioning = child.outputPartitioning
- val boundGenerator = BindReferences.bindReference(generator, child.output)
+ lazy val boundGenerator: Generator = BindReferences.bindReference(generator, child.output)
protected override def doExecute(): RDD[InternalRow] = {
// boundGenerator.terminate() should be triggered after all of the rows in the partition
- val rows = if (join) {
- child.execute().mapPartitionsInternal { iter =>
- val generatorNullRow = new GenericInternalRow(generator.elementSchema.length)
+ val numOutputRows = longMetric("numOutputRows")
+ child.execute().mapPartitionsWithIndexInternal { (index, iter) =>
+ val generatorNullRow = new GenericInternalRow(generator.elementSchema.length)
+ val rows = if (join) {
val joinedRow = new JoinedRow
-
iter.flatMap { row =>
// we should always set the left (child output)
joinedRow.withLeft(row)
@@ -101,18 +101,21 @@ case class GenerateExec(
// keep it the same as Hive does
joinedRow.withRight(row)
}
+ } else {
+ iter.flatMap { row =>
+ val outputRows = boundGenerator.eval(row)
+ if (outer && outputRows.isEmpty) {
+ Seq(generatorNullRow)
+ } else {
+ outputRows
+ }
+ } ++ LazyIterator(boundGenerator.terminate)
}
- } else {
- child.execute().mapPartitionsInternal { iter =>
- iter.flatMap(boundGenerator.eval) ++ LazyIterator(boundGenerator.terminate)
- }
- }
- val numOutputRows = longMetric("numOutputRows")
- rows.mapPartitionsWithIndexInternal { (index, iter) =>
+ // Convert the rows to unsafe rows.
val proj = UnsafeProjection.create(output, output)
proj.initialize(index)
- iter.map { r =>
+ rows.map { r =>
numOutputRows += 1
proj(r)
}
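The user-visible effect of the `outer` flag handled above can be sketched with an outer generator call (illustrative data; assumes a `spark` session and the `explode_outer` function available in this codebase):

```scala
import org.apache.spark.sql.functions.{explode, explode_outer}
import spark.implicits._

val df = Seq((1, Seq("a", "b")), (2, Seq.empty[String])).toDF("id", "items")

// explode() drops row 2 because its array is empty...
df.select($"id", explode($"items")).show()
// ...while explode_outer() keeps it, emitting a null for the generator output.
df.select($"id", explode_outer($"items")).show()
```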
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala
index 19c68c13262a5..514ad7018d8c7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala
@@ -28,12 +28,12 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
*/
case class LocalTableScanExec(
output: Seq[Attribute],
- rows: Seq[InternalRow]) extends LeafExecNode {
+ @transient rows: Seq[InternalRow]) extends LeafExecNode {
override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
- private lazy val unsafeRows: Array[InternalRow] = {
+ @transient private lazy val unsafeRows: Array[InternalRow] = {
if (rows.isEmpty) {
Array.empty
} else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala
index 3c046ce494285..d59b3c6f0caf2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, SessionCatalog}
+import org.apache.spark.sql.catalyst.catalog.{HiveTableRelation, SessionCatalog}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical._
@@ -101,7 +101,7 @@ case class OptimizeMetadataOnlyQuery(
val partitionData = fsRelation.location.listFiles(Nil, Nil)
LocalRelation(partAttrs, partitionData.map(_.values))
- case relation: CatalogRelation =>
+ case relation: HiveTableRelation =>
val partAttrs = getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation)
val caseInsensitiveProperties =
CaseInsensitiveMap(relation.tableMeta.storage.properties)
@@ -137,7 +137,7 @@ case class OptimizeMetadataOnlyQuery(
val partAttrs = getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l)
Some(AttributeSet(partAttrs), l)
- case relation: CatalogRelation if relation.tableMeta.partitionColumnNames.nonEmpty =>
+ case relation: HiveTableRelation if relation.tableMeta.partitionColumnNames.nonEmpty =>
val partAttrs = getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation)
Some(AttributeSet(partAttrs), relation)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 8e8210e334a1d..2e05e5d65923c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution
import java.nio.charset.StandardCharsets
import java.sql.{Date, Timestamp}
-import java.util.TimeZone
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
@@ -187,7 +186,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
DateTimeUtils.dateToString(DateTimeUtils.fromJavaDate(d))
case (t: Timestamp, TimestampType) =>
DateTimeUtils.timestampToString(DateTimeUtils.fromJavaTimestamp(t),
- TimeZone.getTimeZone(sparkSession.sessionState.conf.sessionLocalTimeZone))
+ DateTimeUtils.getTimeZone(sparkSession.sessionState.conf.sessionLocalTimeZone))
case (bin: Array[Byte], BinaryType) => new String(bin, StandardCharsets.UTF_8)
case (decimal: java.math.BigDecimal, DecimalType()) => formatDecimal(decimal)
case (other, tpe) if primitiveTypes.contains(tpe) => other.toString
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
index be35916e3447e..bde7d61b20dc1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -94,7 +94,7 @@ object SQLExecution {
/**
* Wrap an action with a known executionId. When running a different action in a different
* thread from the original one, this method can be used to connect the Spark jobs in this action
- * with the known executionId, e.g., `BroadcastHashJoin.broadcastFuture`.
+ * with the known executionId, e.g., `BroadcastExchangeExec.relationFuture`.
*/
def withExecutionId[T](sc: SparkContext, executionId: String)(body: => T): T = {
val oldExecutionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index cadab37a449aa..4c17a24d14317 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -22,6 +22,9 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, Da
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.ExecutionContext
+import org.codehaus.commons.compiler.CompileException
+import org.codehaus.janino.InternalCompilerException
+
import org.apache.spark.{broadcast, SparkEnv}
import org.apache.spark.internal.Logging
import org.apache.spark.io.CompressionCodec
@@ -353,9 +356,27 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
GenerateMutableProjection.generate(expressions, inputSchema, useSubexprElimination)
}
+ private def genInterpretedPredicate(
+ expression: Expression, inputSchema: Seq[Attribute]): InterpretedPredicate = {
+ val str = expression.toString
+ val logMessage = if (str.length > 256) {
+ str.substring(0, 256 - 3) + "..."
+ } else {
+ str
+ }
+ logWarning(s"Codegen disabled for this expression:\n $logMessage")
+ InterpretedPredicate.create(expression, inputSchema)
+ }
+
protected def newPredicate(
expression: Expression, inputSchema: Seq[Attribute]): GenPredicate = {
- GeneratePredicate.generate(expression, inputSchema)
+ try {
+ GeneratePredicate.generate(expression, inputSchema)
+ } catch {
+ case e @ (_: InternalCompilerException | _: CompileException)
+ if sqlContext == null || sqlContext.conf.wholeStageFallback =>
+ genInterpretedPredicate(expression, inputSchema)
+ }
}
protected def newOrdering(
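The fallback above is gated by `sqlContext.conf.wholeStageFallback`; assuming that flag maps to the `spark.sql.codegen.fallback` configuration key, it could be toggled roughly like this (a sketch, with a `spark` session assumed in scope):

```scala
// Allow falling back to an interpreted predicate when Janino compilation fails.
spark.conf.set("spark.sql.codegen.fallback", "true")
// Surface codegen failures instead of degrading silently (useful when debugging codegen).
spark.conf.set("spark.sql.codegen.fallback", "false")
```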
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
index 6566502bd8a8a..4e718d609c921 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
@@ -36,7 +36,7 @@ class SparkPlanner(
experimentalMethods.extraStrategies ++
extraPlanningStrategies ++ (
FileSourceStrategy ::
- DataSourceStrategy ::
+ DataSourceStrategy(conf) ::
SpecialLimits ::
Aggregation ::
JoinSelection ::
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
index 20dacf88504f1..c2c52894860b5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
@@ -52,7 +52,7 @@ class SparkSqlParser(conf: SQLConf) extends AbstractSqlParser {
/**
* Builder that converts an ANTLR ParseTree into a LogicalPlan/Expression/TableIdentifier.
*/
-class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder {
+class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
import org.apache.spark.sql.catalyst.parser.ParserUtils._
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index ca2f6dd7a84b2..843ce63161220 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -114,7 +114,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
* Matches a plan whose output should be small enough to be used in broadcast join.
*/
private def canBroadcast(plan: LogicalPlan): Boolean = {
- plan.stats(conf).isBroadcastable ||
+ plan.stats(conf).hints.isBroadcastable.getOrElse(false) ||
(plan.stats(conf).sizeInBytes >= 0 &&
plan.stats(conf).sizeInBytes <= conf.autoBroadcastJoinThreshold)
}
@@ -383,8 +383,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.MapGroups(f, key, value, grouping, data, objAttr, child) =>
execution.MapGroupsExec(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil
case logical.FlatMapGroupsWithState(
- f, key, value, grouping, data, output, _, _, _, _, child) =>
- execution.MapGroupsExec(f, key, value, grouping, data, output, planLater(child)) :: Nil
+ f, key, value, grouping, data, output, _, _, _, timeout, child) =>
+ execution.MapGroupsExec(
+ f, key, value, grouping, data, output, timeout, planLater(child)) :: Nil
case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) =>
execution.CoGroupExec(
f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr,
@@ -432,7 +433,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil
case r: LogicalRDD =>
RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil
- case BroadcastHint(child) => planLater(child) :: Nil
+ case h: ResolvedHint => planLater(h.child) :: Nil
case _ => Nil
}
}
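The `hints.isBroadcastable` check above is what the user-facing broadcast hint feeds into; a quick sketch with illustrative table and DataFrame names, assuming the standard `broadcast` function and SQL hint syntax:

```scala
import org.apache.spark.sql.functions.broadcast

// Mark the dimension side as broadcastable regardless of its estimated size.
val joined = factDf.join(broadcast(dimDf), Seq("key"))

// The same hint via SQL, which resolves to a ResolvedHint node as planned above.
spark.sql("SELECT /*+ BROADCAST(d) */ * FROM fact f JOIN dim d ON f.key = d.key")
```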
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index c1e1a631c677e..974315db584da 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -489,13 +489,13 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] {
* Inserts an InputAdapter on top of those that do not support codegen.
*/
private def insertInputAdapter(plan: SparkPlan): SparkPlan = plan match {
- case j @ SortMergeJoinExec(_, _, _, _, left, right) if j.supportCodegen =>
- // The children of SortMergeJoin should do codegen separately.
- j.copy(left = InputAdapter(insertWholeStageCodegen(left)),
- right = InputAdapter(insertWholeStageCodegen(right)))
case p if !supportCodegen(p) =>
// collapse them recursively
InputAdapter(insertWholeStageCodegen(p))
+ case j @ SortMergeJoinExec(_, _, _, _, left, right) =>
+ // The children of SortMergeJoin should do codegen separately.
+ j.copy(left = InputAdapter(insertWholeStageCodegen(left)),
+ right = InputAdapter(insertWholeStageCodegen(right)))
case p =>
p.withNewChildren(p.children.map(insertInputAdapter))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 68c8e6ce62cbb..8e0e27f4dd3a0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -310,7 +310,7 @@ case class HashAggregateExec(
initialBuffer,
bufferSchema,
groupingKeySchema,
- TaskContext.get().taskMemoryManager(),
+ TaskContext.get(),
1024 * 16, // initial capacity
TaskContext.get().taskMemoryManager().pageSizeBytes,
false // disable tracking of performance metrics
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
index 3a7fcf1fa9d89..6e47f9d611199 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen.{BaseOrdering, GenerateOrdering}
import org.apache.spark.sql.execution.UnsafeKVExternalSorter
+import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.KVIterator
@@ -39,7 +40,8 @@ class ObjectAggregationIterator(
newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection,
originalInputAttributes: Seq[Attribute],
inputRows: Iterator[InternalRow],
- fallbackCountThreshold: Int)
+ fallbackCountThreshold: Int,
+ numOutputRows: SQLMetric)
extends AggregationIterator(
groupingExpressions,
originalInputAttributes,
@@ -83,7 +85,9 @@ class ObjectAggregationIterator(
override final def next(): UnsafeRow = {
val entry = aggBufferIterator.next()
- generateOutput(entry.groupingKey, entry.aggregationBuffer)
+ val res = generateOutput(entry.groupingKey, entry.aggregationBuffer)
+ numOutputRows += 1
+ res
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
index 3fcb7ec9a6411..b69500d592ba2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
@@ -94,6 +94,8 @@ case class ObjectHashAggregateExec(
}
}
+ override def outputPartitioning: Partitioning = child.outputPartitioning
+
protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
val numOutputRows = longMetric("numOutputRows")
val fallbackCountThreshold = sqlContext.conf.objectAggSortBasedFallbackThreshold
@@ -117,7 +119,8 @@ case class ObjectHashAggregateExec(
newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled),
child.output,
iter,
- fallbackCountThreshold)
+ fallbackCountThreshold,
+ numOutputRows)
if (!hasInput && groupingExpressions.isEmpty) {
numOutputRows += 1
Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput())
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
index 9316ebcdf105c..3718424931b40 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
@@ -50,10 +50,10 @@ class RowBasedHashMapGenerator(
val keyName = ctx.addReferenceMinorObj(key.name)
key.dataType match {
case d: DecimalType =>
- s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType(
+ s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType(
|${d.precision}, ${d.scale}))""".stripMargin
case _ =>
- s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
+ s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})"""
}
}.mkString("\n").concat(";")
@@ -63,10 +63,10 @@ class RowBasedHashMapGenerator(
val keyName = ctx.addReferenceMinorObj(key.name)
key.dataType match {
case d: DecimalType =>
- s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType(
+ s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType(
|${d.precision}, ${d.scale}))""".stripMargin
case _ =>
- s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
+ s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})"""
}
}.mkString("\n").concat(";")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index 2988161ee5e7b..670c33d03b5b4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -160,7 +160,7 @@ class TungstenAggregationIterator(
initialAggregationBuffer,
StructType.fromAttributes(aggregateFunctions.flatMap(_.aggBufferAttributes)),
StructType.fromAttributes(groupingExpressions.map(_.toAttribute)),
- TaskContext.get().taskMemoryManager(),
+ TaskContext.get(),
1024 * 16, // initial capacity
TaskContext.get().taskMemoryManager().pageSizeBytes,
false // disable tracking of performance metrics
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
index 0c40417db0837..79ae1c010ebce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
@@ -55,10 +55,10 @@ class VectorizedHashMapGenerator(
val keyName = ctx.addReferenceMinorObj(key.name)
key.dataType match {
case d: DecimalType =>
- s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType(
+ s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType(
|${d.precision}, ${d.scale}))""".stripMargin
case _ =>
- s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
+ s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})"""
}
}.mkString("\n").concat(";")
@@ -68,10 +68,10 @@ class VectorizedHashMapGenerator(
val keyName = ctx.addReferenceMinorObj(key.name)
key.dataType match {
case d: DecimalType =>
- s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType(
+ s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType(
|${d.precision}, ${d.scale}))""".stripMargin
case _ =>
- s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
+ s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})"""
}
}.mkString("\n").concat(";")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 44278e37c5276..bd7a5c5d914c1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -21,7 +21,7 @@ import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.Duration
import org.apache.spark.{InterruptibleIterator, SparkException, TaskContext}
-import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD}
+import org.apache.spark.rdd.{EmptyRDD, PartitionwiseSampledRDD, RDD}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer}
@@ -331,29 +331,32 @@ case class SampleExec(
case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
extends LeafExecNode with CodegenSupport {
- def start: Long = range.start
- def step: Long = range.step
- def numSlices: Int = range.numSlices.getOrElse(sparkContext.defaultParallelism)
- def numElements: BigInt = range.numElements
+ val start: Long = range.start
+ val end: Long = range.end
+ val step: Long = range.step
+ val numSlices: Int = range.numSlices.getOrElse(sparkContext.defaultParallelism)
+ val numElements: BigInt = range.numElements
override val output: Seq[Attribute] = range.output
override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
- "numGeneratedRows" -> SQLMetrics.createMetric(sparkContext, "number of generated rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
override lazy val canonicalized: SparkPlan = {
RangeExec(range.canonicalized.asInstanceOf[org.apache.spark.sql.catalyst.plans.logical.Range])
}
override def inputRDDs(): Seq[RDD[InternalRow]] = {
- sqlContext.sparkContext.parallelize(0 until numSlices, numSlices)
- .map(i => InternalRow(i)) :: Nil
+ val rdd = if (start == end || (start < end ^ 0 < step)) {
+ new EmptyRDD[InternalRow](sqlContext.sparkContext)
+ } else {
+ sqlContext.sparkContext.parallelize(0 until numSlices, numSlices).map(i => InternalRow(i))
+ }
+ rdd :: Nil
}
protected override def doProduce(ctx: CodegenContext): String = {
val numOutput = metricTerm(ctx, "numOutputRows")
- val numGenerated = metricTerm(ctx, "numGeneratedRows")
val initTerm = ctx.freshName("initRange")
ctx.addMutableState("boolean", initTerm, s"$initTerm = false;")
@@ -463,9 +466,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
| $number = $batchEnd;
| }
|
- | if ($taskContext.isInterrupted()) {
- | throw new TaskKilledException();
- | }
+ | $taskContext.killTaskIfInterrupted();
|
| long $nextBatchTodo;
| if ($numElementsTodo > ${batchSize}L) {
@@ -540,7 +541,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
}
}
- override def simpleString: String = range.simpleString
+ override def simpleString: String = s"Range ($start, $end, step=$step, splits=$numSlices)"
}
/**
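The `EmptyRDD` short-circuit above covers ranges that cannot produce any rows; for instance (assuming a `spark` session in scope):

```scala
// Both of these plans now produce an EmptyRDD instead of launching per-slice tasks.
spark.range(0).count()        // start == end
spark.range(5, 1, 1).count()  // step points away from end
```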
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index 214e8d309de11..7c2c13e9e98ff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -42,7 +42,9 @@ case class InMemoryTableScanExec(
override def output: Seq[Attribute] = attributes
private def updateAttribute(expr: Expression): Expression = {
- val attrMap = AttributeMap(relation.child.output.zip(output))
+ // Attributes can be pruned, so use the relation's output here instead of this scan's output.
+ // E.g., relation.output is [id, item] but this scan's output can be [item] only.
+ val attrMap = AttributeMap(relation.child.output.zip(relation.output))
expr.transform {
case attr: Attribute => attrMap.getOrElse(attr, attr)
}
@@ -100,7 +102,8 @@ case class InMemoryTableScanExec(
case IsNull(a: Attribute) => statsFor(a).nullCount > 0
case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0
- case In(a: AttributeReference, list: Seq[Expression]) if list.forall(_.isInstanceOf[Literal]) =>
+ case In(a: AttributeReference, list: Seq[Expression])
+ if list.forall(_.isInstanceOf[Literal]) && list.nonEmpty =>
list.map(l => statsFor(a).lowerBound <= l.asInstanceOf[Literal] &&
l.asInstanceOf[Literal] <= statsFor(a).upperBound).reduce(_ || _)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala
index d2ea0cdf61aa6..bf7c22761dc04 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution.command
+import java.net.URI
+
import scala.util.control.NonFatal
import org.apache.hadoop.fs.{FileSystem, Path}
@@ -45,10 +47,10 @@ case class AnalyzeTableCommand(
}
val newTotalSize = AnalyzeTableCommand.calculateTotalSize(sessionState, tableMeta)
- val oldTotalSize = tableMeta.stats.map(_.sizeInBytes.toLong).getOrElse(0L)
+ val oldTotalSize = tableMeta.stats.map(_.sizeInBytes.toLong).getOrElse(-1L)
val oldRowCount = tableMeta.stats.flatMap(_.rowCount.map(_.toLong)).getOrElse(-1L)
var newStats: Option[CatalogStatistics] = None
- if (newTotalSize > 0 && newTotalSize != oldTotalSize) {
+ if (newTotalSize >= 0 && newTotalSize != oldTotalSize) {
newStats = Some(CatalogStatistics(sizeInBytes = newTotalSize))
}
// We only set rowCount when noscan is false, because otherwise:
@@ -81,6 +83,21 @@ case class AnalyzeTableCommand(
object AnalyzeTableCommand extends Logging {
def calculateTotalSize(sessionState: SessionState, catalogTable: CatalogTable): Long = {
+ if (catalogTable.partitionColumnNames.isEmpty) {
+ calculateLocationSize(sessionState, catalogTable.identifier, catalogTable.storage.locationUri)
+ } else {
+ // Calculate the table size as the sum of its visible partitions' sizes. See SPARK-21079
+ val partitions = sessionState.catalog.listPartitions(catalogTable.identifier)
+ partitions.map(p =>
+ calculateLocationSize(sessionState, catalogTable.identifier, p.storage.locationUri)
+ ).sum
+ }
+ }
+
+ private def calculateLocationSize(
+ sessionState: SessionState,
+ tableId: TableIdentifier,
+ locationUri: Option[URI]): Long = {
// This method is mainly based on
// org.apache.hadoop.hive.ql.stats.StatsUtils.getFileSizeForTable(HiveConf, Table)
// in Hive 0.13 (except that we do not use fs.getContentSummary).
@@ -91,13 +108,13 @@ object AnalyzeTableCommand extends Logging {
// countFileSize to count the table size.
val stagingDir = sessionState.conf.getConfString("hive.exec.stagingdir", ".hive-staging")
- def calculateTableSize(fs: FileSystem, path: Path): Long = {
+ def calculateLocationSize(fs: FileSystem, path: Path): Long = {
val fileStatus = fs.getFileStatus(path)
val size = if (fileStatus.isDirectory) {
fs.listStatus(path)
.map { status =>
if (!status.getPath.getName.startsWith(stagingDir)) {
- calculateTableSize(fs, status.getPath)
+ calculateLocationSize(fs, status.getPath)
} else {
0L
}
@@ -109,16 +126,16 @@ object AnalyzeTableCommand extends Logging {
size
}
- catalogTable.storage.locationUri.map { p =>
+ locationUri.map { p =>
val path = new Path(p)
try {
val fs = path.getFileSystem(sessionState.newHadoopConf())
- calculateTableSize(fs, path)
+ calculateLocationSize(fs, path)
} catch {
case NonFatal(e) =>
logWarning(
- s"Failed to get the size of table ${catalogTable.identifier.table} in the " +
- s"database ${catalogTable.identifier.database} because of ${e.toString}", e)
+ s"Failed to get the size of table ${tableId.table} in the " +
+ s"database ${tableId.database} because of ${e.toString}", e)
0L
}
}.getOrElse(0L)
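The per-partition size calculation above is exercised through the usual ANALYZE path; a sketch of the statements involved (table name is illustrative):

```scala
// Sums the sizes of the table's visible partitions (SPARK-21079) rather than only
// the table's root location.
spark.sql("ANALYZE TABLE sales COMPUTE STATISTICS NOSCAN")
// Without NOSCAN, the row count is computed as well.
spark.sql("ANALYZE TABLE sales COMPUTE STATISTICS")
```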
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
index 336f14dd97aea..cfcf3ac6f77f3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
@@ -56,10 +56,8 @@ case class UncacheTableCommand(
override def run(sparkSession: SparkSession): Seq[Row] = {
val tableId = tableIdent.quotedString
- try {
+ if (!ifExists || sparkSession.catalog.tableExists(tableId)) {
sparkSession.catalog.uncacheTable(tableId)
- } catch {
- case _: NoSuchTableException if ifExists => // don't throw
}
Seq.empty[Row]
}
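For reference, this is the behavior the `ifExists` check enables (table name illustrative; `spark` session assumed):

```scala
spark.sql("CACHE TABLE t")
spark.sql("UNCACHE TABLE t")            // fails if t does not exist
spark.sql("UNCACHE TABLE IF EXISTS t")  // no-op when t is missing
```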
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
index 2d890118ae0a5..d05af89df38db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources.BaseRelation
+import org.apache.spark.sql.types.StructType
/**
* A command used to create a data source table.
@@ -87,14 +88,32 @@ case class CreateDataSourceTableCommand(table: CatalogTable, ignoreIfExists: Boo
}
}
- val newTable = table.copy(
- schema = dataSource.schema,
- partitionColumnNames = partitionColumnNames,
- // If metastore partition management for file source tables is enabled, we start off with
- // partition provider hive, but no partitions in the metastore. The user has to call
- // `msck repair table` to populate the table partitions.
- tracksPartitionsInCatalog = partitionColumnNames.nonEmpty &&
- sessionState.conf.manageFilesourcePartitions)
+ val newTable = dataSource match {
+ // Since Spark 2.1, we store the inferred schema of a data source in the metastore, to avoid
+ // inferring the schema again at read path. However, if the data source has overlapping columns
+ // between the data and partition schemas, we can't store it in the metastore, as that breaks
+ // the assumption of the table schema. Here we fall back to the pre-2.1 behavior: store an empty
+ // schema in the metastore and infer it at runtime. Note that this also means the new scalable
+ // partition handling feature (introduced in Spark 2.1) is disabled in this case.
+ case r: HadoopFsRelation if r.overlappedPartCols.nonEmpty =>
+ logWarning("It is not recommended to create a table with overlapped data and partition " +
+ "columns, as Spark cannot store a valid table schema and has to infer it at runtime, " +
+ "which hurts performance. Please check your data files and remove the partition " +
+ "columns in it.")
+ table.copy(schema = new StructType(), partitionColumnNames = Nil)
+
+ case _ =>
+ table.copy(
+ schema = dataSource.schema,
+ partitionColumnNames = partitionColumnNames,
+ // If metastore partition management for file source tables is enabled, we start off with
+ // partition provider hive, but no partitions in the metastore. The user has to call
+ // `msck repair table` to populate the table partitions.
+ tracksPartitionsInCatalog = partitionColumnNames.nonEmpty &&
+ sessionState.conf.manageFilesourcePartitions)
+
+ }
+
// We will return Nil or throw exception at the beginning if the table already exists, so when
// we reach here, the table should not exist and we should set `ignoreIfExists` to false.
sessionState.catalog.createTable(newTable, ignoreIfExists = false)
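A sketch of the situation the fallback above guards against: data files that already contain the partition column (paths and names are illustrative):

```scala
// Suppose the files under /data/t/p=1/ themselves contain a column named "p".
// Creating a table over that layout yields overlapping data and partition schemas,
// so the command above stores an empty schema and defers inference to read time.
spark.sql("CREATE TABLE t USING parquet OPTIONS (path '/data/t')")
```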
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index 55540563ef911..b543c63870f43 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -21,7 +21,6 @@ import java.util.Locale
import scala.collection.{GenMap, GenSeq}
import scala.collection.parallel.ForkJoinTaskSupport
-import scala.concurrent.forkjoin.ForkJoinPool
import scala.util.control.NonFatal
import org.apache.hadoop.conf.Configuration
@@ -36,7 +35,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.execution.datasources.PartitioningUtils
import org.apache.spark.sql.types._
-import org.apache.spark.util.SerializableConfiguration
+import org.apache.spark.util.{SerializableConfiguration, ThreadUtils}
// Note: The definition of these commands are based on the ones described in
// https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL
@@ -200,14 +199,20 @@ case class DropTableCommand(
case _ =>
}
}
- try {
- sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName))
- } catch {
- case _: NoSuchTableException if ifExists =>
- case NonFatal(e) => log.warn(e.toString, e)
+
+ if (catalog.isTemporaryTable(tableName) || catalog.tableExists(tableName)) {
+ try {
+ sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName))
+ } catch {
+ case NonFatal(e) => log.warn(e.toString, e)
+ }
+ catalog.refreshTable(tableName)
+ catalog.dropTable(tableName, ifExists, purge)
+ } else if (ifExists) {
+ // no-op
+ } else {
+ throw new AnalysisException(s"Table or view not found: ${tableName.identifier}")
}
- catalog.refreshTable(tableName)
- catalog.dropTable(tableName, ifExists, purge)
Seq.empty[Row]
}
}
@@ -299,8 +304,8 @@ case class AlterTableChangeColumnCommand(
val resolver = sparkSession.sessionState.conf.resolver
DDLUtils.verifyAlterTableType(catalog, table, isView = false)
- // Find the origin column from schema by column name.
- val originColumn = findColumnByName(table.schema, columnName, resolver)
+ // Find the origin column from dataSchema by column name.
+ val originColumn = findColumnByName(table.dataSchema, columnName, resolver)
// Throw an AnalysisException if the column name/dataType is changed.
if (!columnEqual(originColumn, newColumn, resolver)) {
throw new AnalysisException(
@@ -309,7 +314,7 @@ case class AlterTableChangeColumnCommand(
s"'${newColumn.name}' with type '${newColumn.dataType}'")
}
- val newSchema = table.schema.fields.map { field =>
+ val newDataSchema = table.dataSchema.fields.map { field =>
if (field.name == originColumn.name) {
// Create a new column from the origin column with the new comment.
addComment(field, newColumn.getComment)
@@ -317,8 +322,7 @@ case class AlterTableChangeColumnCommand(
field
}
}
- val newTable = table.copy(schema = StructType(newSchema))
- catalog.alterTable(newTable)
+ catalog.alterTableDataSchema(tableName, StructType(newDataSchema))
Seq.empty[Row]
}
@@ -330,7 +334,8 @@ case class AlterTableChangeColumnCommand(
schema.fields.collectFirst {
case field if resolver(field.name, name) => field
}.getOrElse(throw new AnalysisException(
- s"Invalid column reference '$name', table schema is '${schema}'"))
+ s"Can't find column `$name` given table data columns " +
+ s"${schema.fieldNames.mkString("[`", "`, `", "`]")}"))
}
// Add the comment to a column, if comment is empty, return the original column.
@@ -582,8 +587,15 @@ case class AlterTableRecoverPartitionsCommand(
val threshold = spark.conf.get("spark.rdd.parallelListingThreshold", "10").toInt
val hadoopConf = spark.sparkContext.hadoopConfiguration
val pathFilter = getPathFilter(hadoopConf)
- val partitionSpecsAndLocs = scanPartitions(spark, fs, pathFilter, root, Map(),
- table.partitionColumnNames, threshold, spark.sessionState.conf.resolver)
+
+ val evalPool = ThreadUtils.newForkJoinPool("AlterTableRecoverPartitionsCommand", 8)
+ val partitionSpecsAndLocs: Seq[(TablePartitionSpec, Path)] =
+ try {
+ scanPartitions(spark, fs, pathFilter, root, Map(), table.partitionColumnNames, threshold,
+ spark.sessionState.conf.resolver, new ForkJoinTaskSupport(evalPool)).seq
+ } finally {
+ evalPool.shutdown()
+ }
val total = partitionSpecsAndLocs.length
logInfo(s"Found $total partitions in $root")
@@ -604,8 +616,6 @@ case class AlterTableRecoverPartitionsCommand(
Seq.empty[Row]
}
- @transient private lazy val evalTaskSupport = new ForkJoinTaskSupport(new ForkJoinPool(8))
-
private def scanPartitions(
spark: SparkSession,
fs: FileSystem,
@@ -614,7 +624,8 @@ case class AlterTableRecoverPartitionsCommand(
spec: TablePartitionSpec,
partitionNames: Seq[String],
threshold: Int,
- resolver: Resolver): GenSeq[(TablePartitionSpec, Path)] = {
+ resolver: Resolver,
+ evalTaskSupport: ForkJoinTaskSupport): GenSeq[(TablePartitionSpec, Path)] = {
if (partitionNames.isEmpty) {
return Seq(spec -> path)
}
@@ -638,7 +649,7 @@ case class AlterTableRecoverPartitionsCommand(
val value = ExternalCatalogUtils.unescapePathName(ps(1))
if (resolver(columnName, partitionNames.head)) {
scanPartitions(spark, fs, filter, st.getPath, spec ++ Map(partitionNames.head -> value),
- partitionNames.drop(1), threshold, resolver)
+ partitionNames.drop(1), threshold, resolver, evalTaskSupport)
} else {
logWarning(
s"expected partition column ${partitionNames.head}, but got ${ps(0)}, ignoring it")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
index ebf03e1bf8869..126c1cbc3109a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
@@ -187,11 +187,10 @@ case class AlterTableRenameCommand(
*/
case class AlterTableAddColumnsCommand(
table: TableIdentifier,
- columns: Seq[StructField]) extends RunnableCommand {
+ colsToAdd: Seq[StructField]) extends RunnableCommand {
override def run(sparkSession: SparkSession): Seq[Row] = {
val catalog = sparkSession.sessionState.catalog
val catalogTable = verifyAlterTableAddColumn(catalog, table)
-
try {
sparkSession.catalog.uncacheTable(table.quotedString)
} catch {
@@ -199,12 +198,7 @@ case class AlterTableAddColumnsCommand(
log.warn(s"Exception when attempting to uncache table ${table.quotedString}", e)
}
catalog.refreshTable(table)
-
- // make sure any partition columns are at the end of the fields
- val reorderedSchema = catalogTable.dataSchema ++ columns ++ catalogTable.partitionSchema
- catalog.alterTableSchema(
- table, catalogTable.schema.copy(fields = reorderedSchema.toArray))
-
+ catalog.alterTableDataSchema(table, StructType(catalogTable.dataSchema ++ colsToAdd))
Seq.empty[Row]
}
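The command above backs the `ALTER TABLE ... ADD COLUMNS` statement; a usage sketch (table and column names illustrative):

```scala
// New columns are appended to the data schema; partition columns keep their position at the end.
spark.sql("ALTER TABLE events ADD COLUMNS (source STRING COMMENT 'ingestion source')")
```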
@@ -339,7 +333,7 @@ case class LoadDataCommand(
uri
} else {
val uri = new URI(path)
- if (uri.getScheme() != null && uri.getAuthority() != null) {
+ val hdfsUri = if (uri.getScheme() != null && uri.getAuthority() != null) {
uri
} else {
// Follow Hive's behavior:
@@ -379,6 +373,13 @@ case class LoadDataCommand(
}
new URI(scheme, authority, absolutePath, uri.getQuery(), uri.getFragment())
}
+ val hadoopConf = sparkSession.sessionState.newHadoopConf()
+ val srcPath = new Path(hdfsUri)
+ val fs = srcPath.getFileSystem(hadoopConf)
+ if (!fs.exists(srcPath)) {
+ throw new AnalysisException(s"LOAD DATA input path does not exist: $path")
+ }
+ hdfsUri
}
if (partition.nonEmpty) {
@@ -522,15 +523,15 @@ case class DescribeTableCommand(
throw new AnalysisException(
s"DESC PARTITION is not allowed on a temporary view: ${table.identifier}")
}
- describeSchema(catalog.lookupRelation(table).schema, result)
+ describeSchema(catalog.lookupRelation(table).schema, result, header = false)
} else {
val metadata = catalog.getTableMetadata(table)
if (metadata.schema.isEmpty) {
// In older version(prior to 2.1) of Spark, the table schema can be empty and should be
// inferred at runtime. We should still support it.
- describeSchema(sparkSession.table(metadata.identifier).schema, result)
+ describeSchema(sparkSession.table(metadata.identifier).schema, result, header = false)
} else {
- describeSchema(metadata.schema, result)
+ describeSchema(metadata.schema, result, header = false)
}
describePartitionInfo(metadata, result)
@@ -550,7 +551,7 @@ case class DescribeTableCommand(
private def describePartitionInfo(table: CatalogTable, buffer: ArrayBuffer[Row]): Unit = {
if (table.partitionColumnNames.nonEmpty) {
append(buffer, "# Partition Information", "", "")
- describeSchema(table.partitionSchema, buffer)
+ describeSchema(table.partitionSchema, buffer, header = true)
}
}
@@ -601,8 +602,13 @@ case class DescribeTableCommand(
table.storage.toLinkedHashMap.foreach(s => append(buffer, s._1, s._2, ""))
}
- private def describeSchema(schema: StructType, buffer: ArrayBuffer[Row]): Unit = {
- append(buffer, s"# ${output.head.name}", output(1).name, output(2).name)
+ private def describeSchema(
+ schema: StructType,
+ buffer: ArrayBuffer[Row],
+ header: Boolean): Unit = {
+ if (header) {
+ append(buffer, s"# ${output.head.name}", output(1).name, output(2).name)
+ }
schema.foreach { column =>
append(buffer, column.name, column.dataType.simpleString, column.getComment().orNull)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
index 00f0acab21aa2..3518ee581c5fa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
@@ -159,7 +159,9 @@ case class CreateViewCommand(
checkCyclicViewReference(analyzedPlan, Seq(viewIdent), viewIdent)
// Handles `CREATE OR REPLACE VIEW v0 AS SELECT ...`
- catalog.alterTable(prepareTable(sparkSession, analyzedPlan))
+ // Nothing we need to retain from the old view, so just drop and create a new one
+ catalog.dropTable(viewIdent, ignoreIfNotExists = false, purge = false)
+ catalog.createTable(prepareTable(sparkSession, analyzedPlan), ignoreIfExists = false)
} else {
// Handles `CREATE VIEW v0 AS SELECT ...`. Throws exception when the target view already
// exists.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index f3b209deaae5c..9652f7c25a204 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -23,6 +23,7 @@ import scala.collection.JavaConverters._
import scala.language.{existentials, implicitConversions}
import scala.util.{Failure, Success, Try}
+import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.spark.deploy.SparkHadoopUtil
@@ -39,6 +40,7 @@ import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.{CalendarIntervalType, StructType}
+import org.apache.spark.sql.util.SchemaUtils
import org.apache.spark.util.Utils
/**
@@ -122,7 +124,7 @@ case class DataSource(
val hdfsPath = new Path(path)
val fs = hdfsPath.getFileSystem(hadoopConf)
val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
- SparkHadoopUtil.get.globPathIfNecessary(qualified)
+ SparkHadoopUtil.get.globPathIfNecessary(fs, qualified)
}.toArray
new InMemoryFileIndex(sparkSession, globbedPaths, options, None, fileStatusCache)
}
@@ -181,6 +183,11 @@ case class DataSource(
throw new AnalysisException(
s"Unable to infer schema for $format. It must be specified manually.")
}
+
+ SchemaUtils.checkColumnNameDuplication(
+ (dataSchema ++ partitionSchema).map(_.name), "in the data schema and the partition schema",
+ sparkSession.sessionState.conf.caseSensitiveAnalysis)
+
(dataSchema, partitionSchema)
}
@@ -339,22 +346,8 @@ case class DataSource(
case (format: FileFormat, _) =>
val allPaths = caseInsensitiveOptions.get("path") ++ paths
val hadoopConf = sparkSession.sessionState.newHadoopConf()
- val globbedPaths = allPaths.flatMap { path =>
- val hdfsPath = new Path(path)
- val fs = hdfsPath.getFileSystem(hadoopConf)
- val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
- val globPath = SparkHadoopUtil.get.globPathIfNecessary(qualified)
-
- if (globPath.isEmpty) {
- throw new AnalysisException(s"Path does not exist: $qualified")
- }
- // Sufficient to check head of the globPath seq for non-glob scenario
- // Don't need to check once again if files exist in streaming mode
- if (checkFilesExist && !fs.exists(globPath.head)) {
- throw new AnalysisException(s"Path does not exist: ${globPath.head}")
- }
- globPath
- }.toArray
+ val globbedPaths = allPaths.flatMap(
+ DataSource.checkAndGlobPathIfNecessary(hadoopConf, _, checkFilesExist)).toArray
val fileStatusCache = FileStatusCache.getOrCreate(sparkSession)
val (dataSchema, partitionSchema) = getOrInferFileFormatSchema(format, fileStatusCache)
@@ -430,6 +423,7 @@ case class DataSource(
InsertIntoHadoopFsRelationCommand(
outputPath = outputPath,
staticPartitions = Map.empty,
+ ifPartitionNotExists = false,
partitionColumns = partitionAttributes,
bucketSpec = bucketSpec,
fileFormat = format,
@@ -481,7 +475,7 @@ case class DataSource(
}
}
-object DataSource {
+object DataSource extends Logging {
/** A map to maintain backward compatibility in case we move data sources around. */
private val backwardCompatibilityMap: Map[String, String] = {
@@ -570,10 +564,19 @@ object DataSource {
// there is exactly one registered alias
head.getClass
case sources =>
- // There are multiple registered aliases for the input
- sys.error(s"Multiple sources found for $provider1 " +
- s"(${sources.map(_.getClass.getName).mkString(", ")}), " +
- "please specify the fully qualified class name.")
+ // There are multiple registered aliases for the input. If there is a single data source
+ // whose class name starts with the "org.apache.spark" package, we use it, considering it
+ // an internal data source within Spark.
+ val sourceNames = sources.map(_.getClass.getName)
+ val internalSources = sources.filter(_.getClass.getName.startsWith("org.apache.spark"))
+ if (internalSources.size == 1) {
+ logWarning(s"Multiple sources found for $provider1 (${sourceNames.mkString(", ")}), " +
+ s"defaulting to the internal datasource (${internalSources.head.getClass.getName}).")
+ internalSources.head.getClass
+ } else {
+ throw new AnalysisException(s"Multiple sources found for $provider1 " +
+ s"(${sourceNames.mkString(", ")}), please specify the fully qualified class name.")
+ }
}
} catch {
case e: ServiceConfigurationError if e.getCause.isInstanceOf[NoClassDefFoundError] =>
@@ -600,4 +603,28 @@ object DataSource {
CatalogStorageFormat.empty.copy(
locationUri = path.map(CatalogUtils.stringToURI), properties = optionsWithoutPath)
}
+
+ /**
+ * If `path` is a file pattern, return all the files that match it. Otherwise, return the path
+ * itself. If `checkFilesExist` is `true`, also check that the files exist.
+ */
+ private def checkAndGlobPathIfNecessary(
+ hadoopConf: Configuration,
+ path: String,
+ checkFilesExist: Boolean): Seq[Path] = {
+ val hdfsPath = new Path(path)
+ val fs = hdfsPath.getFileSystem(hadoopConf)
+ val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
+ val globPath = SparkHadoopUtil.get.globPathIfNecessary(fs, qualified)
+
+ if (globPath.isEmpty) {
+ throw new AnalysisException(s"Path does not exist: $qualified")
+ }
+ // Sufficient to check head of the globPath seq for non-glob scenario
+ // Don't need to check once again if files exist in streaming mode
+ if (checkFilesExist && !fs.exists(globPath.head)) {
+ throw new AnalysisException(s"Path does not exist: ${globPath.head}")
+ }
+ globPath
+ }
}
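To illustrate the globbing behavior factored out into the helper above, as seen from the reader API (paths are illustrative; `spark` session assumed):

```scala
// A glob pattern expands to every matching path before the relation is resolved;
// a non-glob path is checked for existence up front and fails fast if missing.
spark.read.parquet("/data/events/year=2017/month=*/part-*.parquet")
spark.read.parquet("/data/does-not-exist")  // AnalysisException: Path does not exist
```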
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 2d83d512e702d..4ec8d55bf0c5d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -24,10 +24,10 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, QualifiedTableName, TableIdentifier}
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, QualifiedTableName}
import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogUtils}
+import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
@@ -48,7 +48,7 @@ import org.apache.spark.unsafe.types.UTF8String
* Note that, this rule must be run after `PreprocessTableCreation` and
* `PreprocessTableInsertion`.
*/
-case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] {
+case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport {
def resolver: Resolver = conf.resolver
@@ -98,11 +98,11 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] {
val potentialSpecs = staticPartitions.filter {
case (partKey, partValue) => resolver(field.name, partKey)
}
- if (potentialSpecs.size == 0) {
+ if (potentialSpecs.isEmpty) {
None
} else if (potentialSpecs.size == 1) {
val partValue = potentialSpecs.head._2
- Some(Alias(Cast(Literal(partValue), field.dataType), field.name)())
+ Some(Alias(cast(Literal(partValue), field.dataType), field.name)())
} else {
throw new AnalysisException(
s"Partition column ${field.name} have multiple values specified, " +
@@ -142,8 +142,8 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] {
parts, query, overwrite, false) if parts.isEmpty =>
InsertIntoDataSourceCommand(l, query, overwrite)
- case InsertIntoTable(
- l @ LogicalRelation(t: HadoopFsRelation, _, table), parts, query, overwrite, false) =>
+ case i @ InsertIntoTable(
+ l @ LogicalRelation(t: HadoopFsRelation, _, table), parts, query, overwrite, _) =>
// If the InsertIntoTable command is for a partitioned HadoopFsRelation and
// the user has specified static partitions, we add a Project operator on top of the query
// to include those constant column values in the query result.
@@ -195,6 +195,7 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] {
InsertIntoHadoopFsRelationCommand(
outputPath,
staticPartitions,
+ i.ifPartitionNotExists,
partitionSchema,
t.bucketSpec,
t.fileFormat,
@@ -208,15 +209,16 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] {
/**
- * Replaces [[CatalogRelation]] with data source table if its table provider is not hive.
+ * Replaces [[UnresolvedCatalogRelation]] with concrete relation logical plans.
+ *
+ * TODO: we should remove the special handling for Hive tables once Hive is fully treated as a
+ * data source.
*/
class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] {
- private def readDataSourceTable(r: CatalogRelation): LogicalPlan = {
- val table = r.tableMeta
+ private def readDataSourceTable(table: CatalogTable): LogicalPlan = {
val qualifiedTableName = QualifiedTableName(table.database, table.identifier.table)
- val cache = sparkSession.sessionState.catalog.tableRelationCache
-
- val plan = cache.get(qualifiedTableName, new Callable[LogicalPlan]() {
+ val catalog = sparkSession.sessionState.catalog
+ catalog.getCachedPlan(qualifiedTableName, new Callable[LogicalPlan]() {
override def call(): LogicalPlan = {
val pathOption = table.storage.locationUri.map("path" -> CatalogUtils.URIToString(_))
val dataSource =
@@ -233,24 +235,30 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan]
LogicalRelation(dataSource.resolveRelation(checkFilesExist = false), table)
}
- }).asInstanceOf[LogicalRelation]
+ })
+ }
- if (r.output.isEmpty) {
- // It's possible that the table schema is empty and need to be inferred at runtime. For this
- // case, we don't need to change the output of the cached plan.
- plan
- } else {
- plan.copy(output = r.output)
- }
+ private def readHiveTable(table: CatalogTable): LogicalPlan = {
+ HiveTableRelation(
+ table,
+ // Hive table columns are always nullable.
+ table.dataSchema.asNullable.toAttributes,
+ table.partitionSchema.asNullable.toAttributes)
}
override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case i @ InsertIntoTable(r: CatalogRelation, _, _, _, _)
- if DDLUtils.isDatasourceTable(r.tableMeta) =>
- i.copy(table = readDataSourceTable(r))
+ case i @ InsertIntoTable(UnresolvedCatalogRelation(tableMeta), _, _, _, _)
+ if DDLUtils.isDatasourceTable(tableMeta) =>
+ i.copy(table = readDataSourceTable(tableMeta))
+
+ case i @ InsertIntoTable(UnresolvedCatalogRelation(tableMeta), _, _, _, _) =>
+ i.copy(table = readHiveTable(tableMeta))
- case r: CatalogRelation if DDLUtils.isDatasourceTable(r.tableMeta) =>
- readDataSourceTable(r)
+ case UnresolvedCatalogRelation(tableMeta) if DDLUtils.isDatasourceTable(tableMeta) =>
+ readDataSourceTable(tableMeta)
+
+ case UnresolvedCatalogRelation(tableMeta) =>
+ readHiveTable(tableMeta)
}
}
@@ -258,7 +266,9 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan]
/**
* A Strategy for planning scans over data sources defined using the sources API.
*/
-object DataSourceStrategy extends Strategy with Logging {
+case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with CastSupport {
+ import DataSourceStrategy._
+
def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match {
case PhysicalOperation(projects, filters, l @ LogicalRelation(t: CatalystScan, _, _)) =>
pruneFilterProjectRaw(
@@ -298,7 +308,7 @@ object DataSourceStrategy extends Strategy with Logging {
// Restriction: Bucket pruning works iff the bucketing column has one and only one column.
def getBucketId(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = {
val mutableRow = new SpecificInternalRow(Seq(bucketColumn.dataType))
- mutableRow(0) = Cast(Literal(value), bucketColumn.dataType).eval(null)
+ mutableRow(0) = cast(Literal(value), bucketColumn.dataType).eval(null)
val bucketIdGeneration = UnsafeProjection.create(
HashPartitioning(bucketColumn :: Nil, numBuckets).partitionIdExpression :: Nil,
bucketColumn :: Nil)
@@ -436,7 +446,9 @@ object DataSourceStrategy extends Strategy with Logging {
private[this] def toCatalystRDD(relation: LogicalRelation, rdd: RDD[Row]): RDD[InternalRow] = {
toCatalystRDD(relation, relation.output, rdd)
}
+}
+object DataSourceStrategy {
/**
* Tries to translate a Catalyst [[Expression]] into data source [[Filter]].
*
@@ -492,7 +504,19 @@ object DataSourceStrategy extends Strategy with Logging {
Some(sources.IsNotNull(a.name))
case expressions.And(left, right) =>
- (translateFilter(left) ++ translateFilter(right)).reduceOption(sources.And)
+ // See SPARK-12218 for detailed discussion
+ // It is not safe to just convert one side if we do not understand the
+ // other side. Here is an example used to explain the reason.
+ // Let's say we have (a = 2 AND trim(b) = 'blah') OR (c > 0)
+ // and we do not understand how to convert trim(b) = 'blah'.
+ // If we only convert a = 2, we will end up with
+ // (a = 2) OR (c > 0), which will generate wrong results.
+ // Pushing one leg of AND down is only safe to do at the top level.
+ // You can see ParquetFilters' createFilter for more details.
+ for {
+ leftFilter <- translateFilter(left)
+ rightFilter <- translateFilter(right)
+ } yield sources.And(leftFilter, rightFilter)
case expressions.Or(left, right) =>
for {
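
To make the comment above concrete, here is a small self-contained sketch (plain Scala, with a toy row class) of why keeping only one leg of an AND underneath an OR is unsafe: the weakened predicate admits a row that the original must drop, so if the data source treats the pushed filter as fully handled, that row is never filtered out.

    // Original predicate: (a = 2 AND trim(b) = 'blah') OR (c > 0); suppose trim(b) = 'blah'
    // cannot be translated. Dropping it yields (a = 2) OR (c > 0), a strictly weaker filter.
    case class ToyRow(a: Int, b: String, c: Int)
    val rows = Seq(ToyRow(a = 2, b = "nope", c = -1))

    val original = rows.filter(r => (r.a == 2 && r.b.trim == "blah") || r.c > 0)
    val weakened = rows.filter(r => r.a == 2 || r.c > 0) // one leg of the AND silently dropped

    assert(original.isEmpty)   // the row must be rejected
    assert(weakened.nonEmpty)  // ... but the weakened filter lets it through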
@@ -527,8 +551,8 @@ object DataSourceStrategy extends Strategy with Logging {
* all [[Filter]]s that are completely filtered at the DataSource.
*/
protected[sql] def selectFilters(
- relation: BaseRelation,
- predicates: Seq[Expression]): (Seq[Expression], Seq[Filter], Set[Filter]) = {
+ relation: BaseRelation,
+ predicates: Seq[Expression]): (Seq[Expression], Seq[Filter], Set[Filter]) = {
// For conciseness, all Catalyst filter expressions of type `expressions.Expression` below are
// called `predicate`s, while all data source filters of type `sources.Filter` are simply called
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 4ec09bff429c5..e87cf8d0f84ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.execution.{QueryExecution, SortExec, SQLExecution}
-import org.apache.spark.sql.types.{StringType, StructType}
+import org.apache.spark.sql.types.StringType
import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -111,9 +111,11 @@ object FileFormatWriter extends Logging {
job.setOutputValueClass(classOf[InternalRow])
FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath))
- val allColumns = queryExecution.logical.output
+ // Pick the attributes from the analyzed plan, as the optimizer may not preserve the case of
+ // the output schema's column names.
+ val allColumns = queryExecution.analyzed.output
val partitionSet = AttributeSet(partitionColumns)
- val dataColumns = queryExecution.logical.output.filterNot(partitionSet.contains)
+ val dataColumns = allColumns.filterNot(partitionSet.contains)
val bucketIdExpression = bucketSpec.map { spec =>
val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)
@@ -170,8 +172,13 @@ object FileFormatWriter extends Logging {
val rdd = if (orderingMatched) {
queryExecution.toRdd
} else {
+ // SPARK-21165: the `requiredOrdering` is based on the attributes from the analyzed plan, and
+ // the physical plan may have different attribute ids due to the optimizer removing some
+ // aliases. Here we bind the expressions ahead of time to avoid a potential attribute-id mismatch.
+ val orderingExpr = requiredOrdering
+ .map(SortOrder(_, Ascending)).map(BindReferences.bindReference(_, allColumns))
SortExec(
- requiredOrdering.map(SortOrder(_, Ascending)),
+ orderingExpr,
global = false,
child = queryExecution.executedPlan).execute()
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala
index 9a08524476baa..89d8a85a9cbd2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution.datasources
+import java.util.Locale
+
import scala.collection.mutable
import org.apache.spark.sql.{SparkSession, SQLContext}
@@ -50,15 +52,22 @@ case class HadoopFsRelation(
override def sqlContext: SQLContext = sparkSession.sqlContext
- val schema: StructType = {
- val getColName: (StructField => String) =
- if (sparkSession.sessionState.conf.caseSensitiveAnalysis) _.name else _.name.toLowerCase
- val overlappedPartCols = mutable.Map.empty[String, StructField]
- partitionSchema.foreach { partitionField =>
- if (dataSchema.exists(getColName(_) == getColName(partitionField))) {
- overlappedPartCols += getColName(partitionField) -> partitionField
- }
+ private def getColName(f: StructField): String = {
+ if (sparkSession.sessionState.conf.caseSensitiveAnalysis) {
+ f.name
+ } else {
+ f.name.toLowerCase(Locale.ROOT)
+ }
+ }
+
+ val overlappedPartCols = mutable.Map.empty[String, StructField]
+ partitionSchema.foreach { partitionField =>
+ if (dataSchema.exists(getColName(_) == getColName(partitionField))) {
+ overlappedPartCols += getColName(partitionField) -> partitionField
}
+ }
+
+ val schema: StructType = {
StructType(dataSchema.map(f => overlappedPartCols.getOrElse(getColName(f), f)) ++
partitionSchema.filterNot(f => overlappedPartCols.contains(getColName(f))))
}
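
A minimal sketch of the overlap handling with hypothetical schemas (the column names and types are made up): under case-insensitive analysis, a partition column that also appears in the data schema keeps the partition-side definition, and the duplicate is dropped from the tail.

    import java.util.Locale
    import org.apache.spark.sql.types._

    val dataSchema = StructType(Seq(StructField("id", LongType), StructField("DS", StringType)))
    val partitionSchema = StructType(Seq(StructField("ds", StringType)))

    // Case-insensitive column name, as when spark.sql.caseSensitive is false.
    def colName(f: StructField): String = f.name.toLowerCase(Locale.ROOT)

    val overlapped = partitionSchema
      .filter(p => dataSchema.exists(d => colName(d) == colName(p)))
      .map(p => colName(p) -> p).toMap
    val merged = StructType(
      dataSchema.map(f => overlapped.getOrElse(colName(f), f)) ++
        partitionSchema.filterNot(f => overlapped.contains(colName(f))))
    // merged == StructType(Seq(StructField("id", LongType), StructField("ds", StringType)))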
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala
index 9897ab73b0da8..91e31650617ec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala
@@ -27,6 +27,7 @@ import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
import org.apache.spark.internal.Logging
import org.apache.spark.metrics.source.HiveCatalogMetrics
+import org.apache.spark.sql.execution.streaming.FileStreamSink
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration
@@ -36,20 +37,28 @@ import org.apache.spark.util.SerializableConfiguration
* A [[FileIndex]] that generates the list of files to process by recursively listing all the
* files present in `paths`.
*
- * @param rootPaths the list of root table paths to scan
+ * @param rootPathsSpecified the list of root table paths to scan (some of which might be
+ * filtered out later)
* @param parameters as set of options to control discovery
* @param partitionSchema an optional partition schema that will be use to provide types for the
* discovered partitions
*/
class InMemoryFileIndex(
sparkSession: SparkSession,
- override val rootPaths: Seq[Path],
+ rootPathsSpecified: Seq[Path],
parameters: Map[String, String],
partitionSchema: Option[StructType],
fileStatusCache: FileStatusCache = NoopCache)
extends PartitioningAwareFileIndex(
sparkSession, parameters, partitionSchema, fileStatusCache) {
+ // Filter out streaming metadata dirs or files such as "/.../_spark_metadata" (the metadata dir)
+ // or "/.../_spark_metadata/0" (a file in the metadata dir). `rootPathsSpecified` might contain
+ // such streaming metadata dirs or files, e.g. after globbing "basePath/*" where "basePath"
+ // is the output of a streaming query.
+ override val rootPaths =
+ rootPathsSpecified.filterNot(FileStreamSink.ancestorIsMetadataDirectory(_, hadoopConf))
+
@volatile private var cachedLeafFiles: mutable.LinkedHashMap[Path, FileStatus] = _
@volatile private var cachedLeafDirToChildrenFiles: Map[Path, Array[FileStatus]] = _
@volatile private var cachedPartitionSpec: PartitionSpec = _
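
A simplified stand-in for the filtering described above (the real FileStreamSink.ancestorIsMetadataDirectory also qualifies the path against the file system; "_spark_metadata" is the sink's metadata directory name):

    import org.apache.hadoop.fs.Path

    def ancestorIsMetadataDir(p: Path): Boolean = {
      var current = p
      while (current != null) {
        if (current.getName == "_spark_metadata") return true
        current = current.getParent
      }
      false
    }

    val roots = Seq(new Path("/out/part=1"), new Path("/out/_spark_metadata/0"))
    roots.filterNot(ancestorIsMetadataDir)  // keeps only /out/part=1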
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
index 19b51d4d9530a..c9d31449d3629 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
@@ -37,10 +37,13 @@ import org.apache.spark.sql.execution.command._
* overwrites: when the spec is empty, all partitions are overwritten.
* When it covers a prefix of the partition keys, only partitions matching
* the prefix are overwritten.
+ * @param ifPartitionNotExists If true, only write if the partition does not exist.
+ * Only valid for static partitions.
*/
case class InsertIntoHadoopFsRelationCommand(
outputPath: Path,
staticPartitions: TablePartitionSpec,
+ ifPartitionNotExists: Boolean,
partitionColumns: Seq[Attribute],
bucketSpec: Option[BucketSpec],
fileFormat: FileFormat,
@@ -61,8 +64,8 @@ case class InsertIntoHadoopFsRelationCommand(
val duplicateColumns = query.schema.fieldNames.groupBy(identity).collect {
case (x, ys) if ys.length > 1 => "\"" + x + "\""
}.mkString(", ")
- throw new AnalysisException(s"Duplicate column(s) : $duplicateColumns found, " +
- s"cannot save to file.")
+ throw new AnalysisException(s"Duplicate column(s): $duplicateColumns found, " +
+ "cannot save to file.")
}
val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(options)
@@ -76,11 +79,12 @@ case class InsertIntoHadoopFsRelationCommand(
var initialMatchingPartitions: Seq[TablePartitionSpec] = Nil
var customPartitionLocations: Map[TablePartitionSpec, String] = Map.empty
+ var matchingPartitions: Seq[CatalogTablePartition] = Seq.empty
// When partitions are tracked by the catalog, compute all custom partition locations that
// may be relevant to the insertion job.
if (partitionsTrackedByCatalog) {
- val matchingPartitions = sparkSession.sessionState.catalog.listPartitions(
+ matchingPartitions = sparkSession.sessionState.catalog.listPartitions(
catalogTable.get.identifier, Some(staticPartitions))
initialMatchingPartitions = matchingPartitions.map(_.spec)
customPartitionLocations = getCustomPartitionLocations(
@@ -101,8 +105,12 @@ case class InsertIntoHadoopFsRelationCommand(
case (SaveMode.ErrorIfExists, true) =>
throw new AnalysisException(s"path $qualifiedOutputPath already exists.")
case (SaveMode.Overwrite, true) =>
- deleteMatchingPartitions(fs, qualifiedOutputPath, customPartitionLocations, committer)
- true
+ if (ifPartitionNotExists && matchingPartitions.nonEmpty) {
+ false
+ } else {
+ deleteMatchingPartitions(fs, qualifiedOutputPath, customPartitionLocations, committer)
+ true
+ }
case (SaveMode.Append, _) | (SaveMode.Overwrite, _) | (SaveMode.ErrorIfExists, false) =>
true
case (SaveMode.Ignore, exists) =>
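
A sketch of the resulting write/skip decision, with the inputs pulled out as plain parameters (skipBecausePartitionExists stands for ifPartitionNotExists && matchingPartitions.nonEmpty above):

    import org.apache.spark.sql.SaveMode

    def doInsertion(mode: SaveMode, pathExists: Boolean,
                    skipBecausePartitionExists: Boolean): Boolean =
      (mode, pathExists) match {
        case (SaveMode.ErrorIfExists, true) =>
          throw new IllegalStateException("path already exists")
        case (SaveMode.Overwrite, true) =>
          // Skip the write when the targeted static partition already exists and
          // ifPartitionNotExists is set; otherwise delete matching partitions and write.
          !skipBecausePartitionExists
        case (SaveMode.Ignore, exists) =>
          !exists
        case _ =>
          true
      }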
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
index ffd7f6c750f85..6b6f6388d54e8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
@@ -177,7 +177,7 @@ abstract class PartitioningAwareFileIndex(
})
val selected = partitions.filter {
- case PartitionPath(values, _) => boundPredicate(values)
+ case PartitionPath(values, _) => boundPredicate.eval(values)
}
logInfo {
val total = partitions.length
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
index c3583209efc56..6f7438192dfe2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
@@ -94,7 +94,7 @@ object PartitioningUtils {
typeInference: Boolean,
basePaths: Set[Path],
timeZoneId: String): PartitionSpec = {
- parsePartitions(paths, typeInference, basePaths, TimeZone.getTimeZone(timeZoneId))
+ parsePartitions(paths, typeInference, basePaths, DateTimeUtils.getTimeZone(timeZoneId))
}
private[datasources] def parsePartitions(
@@ -138,7 +138,7 @@ object PartitioningUtils {
"root directory of the table. If there are multiple root directories, " +
"please load them separately and then union them.")
- val resolvedPartitionValues = resolvePartitions(pathsWithPartitionValues)
+ val resolvedPartitionValues = resolvePartitions(pathsWithPartitionValues, timeZone)
// Creates the StructType which represents the partition columns.
val fields = {
@@ -243,7 +243,7 @@ object PartitioningUtils {
if (equalSignIndex == -1) {
None
} else {
- val columnName = columnSpec.take(equalSignIndex)
+ val columnName = unescapePathName(columnSpec.take(equalSignIndex))
assert(columnName.nonEmpty, s"Empty partition column name in '$columnSpec'")
val rawColumnValue = columnSpec.drop(equalSignIndex + 1)
@@ -322,7 +322,8 @@ object PartitioningUtils {
* }}}
*/
def resolvePartitions(
- pathsWithPartitionValues: Seq[(Path, PartitionValues)]): Seq[PartitionValues] = {
+ pathsWithPartitionValues: Seq[(Path, PartitionValues)],
+ timeZone: TimeZone): Seq[PartitionValues] = {
if (pathsWithPartitionValues.isEmpty) {
Seq.empty
} else {
@@ -337,7 +338,7 @@ object PartitioningUtils {
val values = pathsWithPartitionValues.map(_._2)
val columnCount = values.head.columnNames.size
val resolvedValues = (0 until columnCount).map { i =>
- resolveTypeConflicts(values.map(_.literals(i)))
+ resolveTypeConflicts(values.map(_.literals(i)), timeZone)
}
// Fills resolved literals back to each partition
@@ -474,7 +475,7 @@ object PartitioningUtils {
* Given a collection of [[Literal]]s, resolves possible type conflicts by up-casting "lower"
* types.
*/
- private def resolveTypeConflicts(literals: Seq[Literal]): Seq[Literal] = {
+ private def resolveTypeConflicts(literals: Seq[Literal], timeZone: TimeZone): Seq[Literal] = {
val desiredType = {
val topType = literals.map(_.dataType).maxBy(upCastingOrder.indexOf(_))
// Falls back to string if all values of this column are null or empty string
@@ -482,7 +483,7 @@ object PartitioningUtils {
}
literals.map { case l @ Literal(_, dataType) =>
- Literal.create(Cast(l, desiredType).eval(), desiredType)
+ Literal.create(Cast(l, desiredType, Some(timeZone.getID)).eval(), desiredType)
}
}
}
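
The reason a time zone now threads through resolvePartitions: casting the same partition string to a timestamp yields different instants for different session time zones. A small sketch using the same internal Catalyst Cast/Literal expressions as the code above (evaluated eagerly here only for illustration):

    import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
    import org.apache.spark.sql.types.TimestampType

    val raw = Literal("2016-01-01 00:00:00")
    // The resulting microsecond values differ by the zone offset.
    val inUtc = Cast(raw, TimestampType, Some("UTC")).eval()
    val inJst = Cast(raw, TimestampType, Some("Asia/Tokyo")).eval()
    assert(inUtc != inJst)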
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
index 905b8683e10bd..f5df1848a38c4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.execution.datasources
+import org.apache.spark.sql.catalyst.catalog.CatalogStatistics
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
@@ -59,8 +60,11 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq)
val prunedFsRelation =
fsRelation.copy(location = prunedFileIndex)(sparkSession)
- val prunedLogicalRelation = logicalRelation.copy(relation = prunedFsRelation)
-
+ // Change table stats based on the sizeInBytes of pruned files
+ val withStats = logicalRelation.catalogTable.map(_.copy(
+ stats = Some(CatalogStatistics(sizeInBytes = BigInt(prunedFileIndex.sizeInBytes)))))
+ val prunedLogicalRelation = logicalRelation.copy(
+ relation = prunedFsRelation, catalogTable = withStats)
// Keep partition-pruning predicates so that they are visible in physical planning
val filterExpression = filters.reduceLeft(And)
val filter = Filter(filterExpression, prunedLogicalRelation)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala
index 6f19ea195c0cd..b92684c5d3807 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala
@@ -49,4 +49,11 @@ case class SaveIntoDataSourceCommand(
Seq.empty[Row]
}
+
+ override def simpleString: String = {
+ val redacted = SparkSession.getActiveSession
+ .map(_.sessionState.conf.redactOptions(options))
+ .getOrElse(Map())
+ s"SaveIntoDataSourceCommand ${provider}, ${redacted}, ${mode}"
+ }
}
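
A rough sketch of what the redaction buys (the real rule set comes from SQLConf.redactOptions and a configurable regex; the pattern and placeholder below are made up for illustration):

    val sensitiveKey = "(?i).*(password|secret|token).*".r

    def redactOptions(options: Map[String, String]): Map[String, String] =
      options.map { case (k, v) =>
        k -> (if (sensitiveKey.pattern.matcher(k).matches()) "*********(redacted)" else v)
      }

    redactOptions(Map("url" -> "jdbc:postgresql://db", "password" -> "hunter2"))
    // Map(url -> jdbc:postgresql://db, password -> *********(redacted))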
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
index 83bdf6fe224be..2de58384f9834 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
@@ -111,8 +111,8 @@ abstract class CSVDataSource extends Serializable {
object CSVDataSource {
def apply(options: CSVOptions): CSVDataSource = {
- if (options.wholeFile) {
- WholeFileCSVDataSource
+ if (options.multiLine) {
+ MultiLineCSVDataSource
} else {
TextInputCSVDataSource
}
@@ -196,7 +196,7 @@ object TextInputCSVDataSource extends CSVDataSource {
}
}
-object WholeFileCSVDataSource extends CSVDataSource {
+object MultiLineCSVDataSource extends CSVDataSource {
override val isSplitable: Boolean = false
override def readFile(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index 62e4c6e4b4ea0..a13a5a34b4a84 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -117,7 +117,7 @@ class CSVOptions(
name.map(CompressionCodecs.getCodecClassName)
}
- val timeZone: TimeZone = TimeZone.getTimeZone(
+ val timeZone: TimeZone = DateTimeUtils.getTimeZone(
parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId))
// Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe.
@@ -128,7 +128,7 @@ class CSVOptions(
FastDateFormat.getInstance(
parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, Locale.US)
- val wholeFile = parameters.get("wholeFile").map(_.toBoolean).getOrElse(false)
+ val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false)
val maxColumns = getInt("maxColumns", 20480)
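
With the rename, a multi-line CSV read (previously option "wholeFile") looks like this; spark is an existing SparkSession and the path is hypothetical:

    val df = spark.read
      .option("multiLine", "true")
      .option("header", "true")
      .csv("/data/records_with_embedded_newlines.csv")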
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala
index 7a6c0f9fed2f9..1723596de1db2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala
@@ -32,6 +32,13 @@ import org.apache.spark.util.Utils
*/
object DriverRegistry extends Logging {
+ /**
+ * Load DriverManager first to avoid any race condition between the
+ * DriverManager static initialization block and a specific driver class's
+ * static initialization block, e.g. PhoenixDriver.
+ */
+ DriverManager.getDrivers
+
private val wrapperMap: mutable.Map[String, DriverWrapper] = mutable.Map.empty
def register(className: String): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
index 591096d5efd22..96a8a51da18e5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
@@ -97,10 +97,13 @@ class JDBCOptions(
val lowerBound = parameters.get(JDBC_LOWER_BOUND).map(_.toLong)
// the upper bound of the partition column
val upperBound = parameters.get(JDBC_UPPER_BOUND).map(_.toLong)
- require(partitionColumn.isEmpty ||
- (lowerBound.isDefined && upperBound.isDefined && numPartitions.isDefined),
- s"If '$JDBC_PARTITION_COLUMN' is specified then '$JDBC_LOWER_BOUND', '$JDBC_UPPER_BOUND'," +
- s" and '$JDBC_NUM_PARTITIONS' are required.")
+ // numPartitions is also used for data source writing
+ require((partitionColumn.isEmpty && lowerBound.isEmpty && upperBound.isEmpty) ||
+ (partitionColumn.isDefined && lowerBound.isDefined && upperBound.isDefined &&
+ numPartitions.isDefined),
+ s"When reading JDBC data sources, users need to specify all or none for the following " +
+ s"options: '$JDBC_PARTITION_COLUMN', '$JDBC_LOWER_BOUND', '$JDBC_UPPER_BOUND', " +
+ s"and '$JDBC_NUM_PARTITIONS'")
val fetchSize = {
val size = parameters.getOrElse(JDBC_BATCH_FETCH_SIZE, "0").toInt
require(size >= 0,
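
In user-facing terms, the check means a partitioned JDBC read must carry all four options together, or none of them; the connection details and names below are hypothetical:

    val orders = spark.read
      .format("jdbc")
      .option("url", "jdbc:postgresql://db.example.com/sales")
      .option("dbtable", "orders")
      .option("partitionColumn", "order_id")
      .option("lowerBound", "1")
      .option("upperBound", "1000000")
      .option("numPartitions", "8")
      .load()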
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 2bdc43254133e..7097069b92b78 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -286,7 +286,7 @@ private[jdbc] class JDBCRDD(
conn = getConnection()
val dialect = JdbcDialects.get(url)
import scala.collection.JavaConverters._
- dialect.beforeFetch(conn, options.asConnectionProperties.asScala.toMap)
+ dialect.beforeFetch(conn, options.asProperties.asScala.toMap)
// H2's JDBC driver does not support the setSchema() method. We pass a
// fully-qualified table name in the SELECT statement. I don't know how to
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index 8b45dba04d29e..272cb4a82641e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -64,7 +64,8 @@ private[sql] object JDBCRelation extends Logging {
s"bound. Lower bound: $lowerBound; Upper bound: $upperBound")
val numPartitions =
- if ((upperBound - lowerBound) >= partitioning.numPartitions) {
+ if ((upperBound - lowerBound) >= partitioning.numPartitions || /* check for overflow */
+ (upperBound - lowerBound) < 0) {
partitioning.numPartitions
} else {
logWarning("The number of partitions is reduced because the specified number of " +
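
The extra "< 0" guard covers the case where the subtraction itself overflows, e.g.:

    val lowerBound = Long.MinValue
    val upperBound = Long.MaxValue
    val span = upperBound - lowerBound  // wraps around to -1
    assert(span < 0)                    // caught by the new check; keep numPartitions as requested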
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
index 74dcfb06f5c2b..37e7bb0a59bb6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
@@ -29,6 +29,8 @@ class JdbcRelationProvider extends CreatableRelationProvider
override def createRelation(
sqlContext: SQLContext,
parameters: Map[String, String]): BaseRelation = {
+ import JDBCOptions._
+
val jdbcOptions = new JDBCOptions(parameters)
val partitionColumn = jdbcOptions.partitionColumn
val lowerBound = jdbcOptions.lowerBound
@@ -36,10 +38,13 @@ class JdbcRelationProvider extends CreatableRelationProvider
val numPartitions = jdbcOptions.numPartitions
val partitionInfo = if (partitionColumn.isEmpty) {
- assert(lowerBound.isEmpty && upperBound.isEmpty)
+ assert(lowerBound.isEmpty && upperBound.isEmpty, "When 'partitionColumn' is not specified, " +
+ s"'$JDBC_LOWER_BOUND' and '$JDBC_UPPER_BOUND' are expected to be empty")
null
} else {
- assert(lowerBound.nonEmpty && upperBound.nonEmpty && numPartitions.nonEmpty)
+ assert(lowerBound.nonEmpty && upperBound.nonEmpty && numPartitions.nonEmpty,
+ s"When 'partitionColumn' is specified, '$JDBC_LOWER_BOUND', '$JDBC_UPPER_BOUND', and " +
+ s"'$JDBC_NUM_PARTITIONS' are also required")
JDBCPartitioningInfo(
partitionColumn.get, lowerBound.get, upperBound.get, numPartitions.get)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 5fc3c2753b6cf..ce0610fc09394 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -440,8 +440,9 @@ object JdbcUtils extends Logging {
case StringType =>
(array: Object) =>
- array.asInstanceOf[Array[java.lang.String]]
- .map(UTF8String.fromString)
+ // Some underlying types are not String, such as uuid, inet, cidr, etc.
+ array.asInstanceOf[Array[java.lang.Object]]
+ .map(obj => if (obj == null) null else UTF8String.fromString(obj.toString))
case DateType =>
(array: Object) =>
@@ -652,8 +653,17 @@ object JdbcUtils extends Logging {
case e: SQLException =>
val cause = e.getNextException
if (cause != null && e.getCause != cause) {
+ // If there is no cause already, set the 'next exception' as the cause. A null
+ // getCause *may* mean that no cause has been set yet ...
if (e.getCause == null) {
- e.initCause(cause)
+ try {
+ e.initCause(cause)
+ } catch {
+ // ... or that the cause *was* explicitly initialized to *null*, in which case
+ // initCause fails with IllegalStateException. There is no way to tell the two
+ // apart, so fall back to addSuppressed here as well.
+ case _: IllegalStateException => e.addSuppressed(cause)
+ }
} else {
e.addSuppressed(cause)
}
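
A self-contained sketch of the corner case handled above: Throwable.initCause throws IllegalStateException once a cause has already been set, even when it was explicitly set to null, and getCause cannot distinguish the two states.

    import java.sql.SQLException

    val e = new SQLException("batch failed")
    e.initCause(null)                          // cause explicitly initialized ... to null
    val next = new SQLException("root cause")  // stand-in for e.getNextException

    try {
      e.initCause(next)                        // fails: a cause was already set
    } catch {
      case _: IllegalStateException => e.addSuppressed(next)
    }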
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
index 4f2963da9ace9..5a92a71d19e78 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
@@ -86,8 +86,8 @@ abstract class JsonDataSource extends Serializable {
object JsonDataSource {
def apply(options: JSONOptions): JsonDataSource = {
- if (options.wholeFile) {
- WholeFileJsonDataSource
+ if (options.multiLine) {
+ MultiLineJsonDataSource
} else {
TextInputJsonDataSource
}
@@ -147,7 +147,7 @@ object TextInputJsonDataSource extends JsonDataSource {
}
}
-object WholeFileJsonDataSource extends JsonDataSource {
+object MultiLineJsonDataSource extends JsonDataSource {
override val isSplitable: Boolean = {
false
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
index fb632cf2bb70e..09879690e5f98 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
@@ -326,8 +326,8 @@ private[sql] object JsonInferSchema {
ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2)
// The case that given `DecimalType` is capable of given `IntegralType` is handled in
- // `findTightestCommonTypeOfTwo`. Both cases below will be executed only when
- // the given `DecimalType` is not capable of the given `IntegralType`.
+ // `findTightestCommonType`. Both cases below will be executed only when the given
+ // `DecimalType` is not capable of the given `IntegralType`.
case (t1: IntegralType, t2: DecimalType) =>
compatibleType(DecimalType.forType(t1), t2)
case (t1: DecimalType, t2: IntegralType) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
index 2f3a2c62b912c..1d60495fe1dc9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
@@ -50,7 +50,7 @@ import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
-import org.apache.spark.util.SerializableConfiguration
+import org.apache.spark.util.{SerializableConfiguration, ThreadUtils}
class ParquetFileFormat
extends FileFormat
@@ -85,7 +85,7 @@ class ParquetFileFormat
conf.getClass(
SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key,
classOf[ParquetOutputCommitter],
- classOf[ParquetOutputCommitter])
+ classOf[OutputCommitter])
if (conf.get(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key) == null) {
logInfo("Using default output committer for Parquet: " +
@@ -97,7 +97,7 @@ class ParquetFileFormat
conf.setClass(
SQLConf.OUTPUT_COMMITTER_CLASS.key,
committerClass,
- classOf[ParquetOutputCommitter])
+ classOf[OutputCommitter])
// We're not really using `ParquetOutputFormat[Row]` for writing data here, because we override
// it in `ParquetOutputWriter` to support appending and dynamic partitioning. The reason why
@@ -137,6 +137,14 @@ class ParquetFileFormat
conf.setBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, false)
}
+ if (conf.getBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, false)
+ && !classOf[ParquetOutputCommitter].isAssignableFrom(committerClass)) {
+ // output summary is requested, but the class is not a Parquet Committer
+ logWarning(s"Committer $committerClass is not a ParquetOutputCommitter and cannot" +
+ s" create job summaries. " +
+ s"Set Parquet option ${ParquetOutputFormat.ENABLE_JOB_SUMMARY} to false.")
+ }
+
new OutputWriterFactory {
// This OutputWriterFactory instance is deserialized when writing Parquet files on the
// executor side without constructing or deserializing ParquetFileFormat. Therefore, we hold
@@ -479,24 +487,29 @@ object ParquetFileFormat extends Logging {
partFiles: Seq[FileStatus],
ignoreCorruptFiles: Boolean): Seq[Footer] = {
val parFiles = partFiles.par
- parFiles.tasksupport = new ForkJoinTaskSupport(new ForkJoinPool(8))
- parFiles.flatMap { currentFile =>
- try {
- // Skips row group information since we only need the schema.
- // ParquetFileReader.readFooter throws RuntimeException, instead of IOException,
- // when it can't read the footer.
- Some(new Footer(currentFile.getPath(),
- ParquetFileReader.readFooter(
- conf, currentFile, SKIP_ROW_GROUPS)))
- } catch { case e: RuntimeException =>
- if (ignoreCorruptFiles) {
- logWarning(s"Skipped the footer in the corrupted file: $currentFile", e)
- None
- } else {
- throw new IOException(s"Could not read footer for file: $currentFile", e)
+ val pool = ThreadUtils.newForkJoinPool("readingParquetFooters", 8)
+ parFiles.tasksupport = new ForkJoinTaskSupport(pool)
+ try {
+ parFiles.flatMap { currentFile =>
+ try {
+ // Skips row group information since we only need the schema.
+ // ParquetFileReader.readFooter throws RuntimeException, instead of IOException,
+ // when it can't read the footer.
+ Some(new Footer(currentFile.getPath(),
+ ParquetFileReader.readFooter(
+ conf, currentFile, SKIP_ROW_GROUPS)))
+ } catch { case e: RuntimeException =>
+ if (ignoreCorruptFiles) {
+ logWarning(s"Skipped the footer in the corrupted file: $currentFile", e)
+ None
+ } else {
+ throw new IOException(s"Could not read footer for file: $currentFile", e)
+ }
}
- }
- }.seq
+ }.seq
+ } finally {
+ pool.shutdown()
+ }
}
/**
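
The shape of the fix, reduced to a self-contained sketch (ThreadUtils.newForkJoinPool is Spark-internal; this uses Scala 2.11's bundled fork-join pool directly): the pool backing the parallel collection is created per call and always shut down, rather than leaked.

    import scala.collection.parallel.ForkJoinTaskSupport
    import scala.concurrent.forkjoin.ForkJoinPool

    val work = (1 to 100).par
    val pool = new ForkJoinPool(8)
    work.tasksupport = new ForkJoinTaskSupport(pool)
    try {
      work.map(_ * 2).seq
    } finally {
      pool.shutdown()  // without this, each call would leak an 8-thread pool
    }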
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
index a6a6cef5861f3..763841efbd9f3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
@@ -166,7 +166,14 @@ private[parquet] object ParquetFilters {
* Converts data sources filters to Parquet filter predicates.
*/
def createFilter(schema: StructType, predicate: sources.Filter): Option[FilterPredicate] = {
- val dataTypeOf = getFieldMap(schema)
+ val nameToType = getFieldMap(schema)
+
+ // Parquet does not allow dots in the column name because dots are used as a column path
+ // delimiter. Since Parquet 1.8.2 (PARQUET-389), Parquet accepts the filter predicates
+ // with missing columns, so Parquet may return incorrect results when we push down filters
+ // for columns whose names contain dots. Thus, we do not push down such filters.
+ // See SPARK-20364.
+ def canMakeFilterOn(name: String): Boolean = nameToType.contains(name) && !name.contains(".")
// NOTE:
//
@@ -184,30 +191,30 @@ private[parquet] object ParquetFilters {
// Probably I missed something and obviously this should be changed.
predicate match {
- case sources.IsNull(name) if dataTypeOf.contains(name) =>
- makeEq.lift(dataTypeOf(name)).map(_(name, null))
- case sources.IsNotNull(name) if dataTypeOf.contains(name) =>
- makeNotEq.lift(dataTypeOf(name)).map(_(name, null))
-
- case sources.EqualTo(name, value) if dataTypeOf.contains(name) =>
- makeEq.lift(dataTypeOf(name)).map(_(name, value))
- case sources.Not(sources.EqualTo(name, value)) if dataTypeOf.contains(name) =>
- makeNotEq.lift(dataTypeOf(name)).map(_(name, value))
-
- case sources.EqualNullSafe(name, value) if dataTypeOf.contains(name) =>
- makeEq.lift(dataTypeOf(name)).map(_(name, value))
- case sources.Not(sources.EqualNullSafe(name, value)) if dataTypeOf.contains(name) =>
- makeNotEq.lift(dataTypeOf(name)).map(_(name, value))
-
- case sources.LessThan(name, value) if dataTypeOf.contains(name) =>
- makeLt.lift(dataTypeOf(name)).map(_(name, value))
- case sources.LessThanOrEqual(name, value) if dataTypeOf.contains(name) =>
- makeLtEq.lift(dataTypeOf(name)).map(_(name, value))
-
- case sources.GreaterThan(name, value) if dataTypeOf.contains(name) =>
- makeGt.lift(dataTypeOf(name)).map(_(name, value))
- case sources.GreaterThanOrEqual(name, value) if dataTypeOf.contains(name) =>
- makeGtEq.lift(dataTypeOf(name)).map(_(name, value))
+ case sources.IsNull(name) if canMakeFilterOn(name) =>
+ makeEq.lift(nameToType(name)).map(_(name, null))
+ case sources.IsNotNull(name) if canMakeFilterOn(name) =>
+ makeNotEq.lift(nameToType(name)).map(_(name, null))
+
+ case sources.EqualTo(name, value) if canMakeFilterOn(name) =>
+ makeEq.lift(nameToType(name)).map(_(name, value))
+ case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name) =>
+ makeNotEq.lift(nameToType(name)).map(_(name, value))
+
+ case sources.EqualNullSafe(name, value) if canMakeFilterOn(name) =>
+ makeEq.lift(nameToType(name)).map(_(name, value))
+ case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name) =>
+ makeNotEq.lift(nameToType(name)).map(_(name, value))
+
+ case sources.LessThan(name, value) if canMakeFilterOn(name) =>
+ makeLt.lift(nameToType(name)).map(_(name, value))
+ case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name) =>
+ makeLtEq.lift(nameToType(name)).map(_(name, value))
+
+ case sources.GreaterThan(name, value) if canMakeFilterOn(name) =>
+ makeGt.lift(nameToType(name)).map(_(name, value))
+ case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name) =>
+ makeGtEq.lift(nameToType(name)).map(_(name, value))
case sources.And(lhs, rhs) =>
// At here, it is not safe to just convert one side if we do not understand the
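
A tiny sketch of the new guard with a hypothetical field map: a top-level column whose name literally contains a dot is never pushed down, since Parquet would read the dot as a path separator.

    val nameToType = Map("a" -> "int", "a.b" -> "string")  // hypothetical flat columns

    def canMakeFilterOn(name: String): Boolean = nameToType.contains(name) && !name.contains(".")

    assert(canMakeFilterOn("a"))
    assert(!canMakeFilterOn("a.b"))  // skipped; Spark still evaluates the filter after the scan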
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala
index 38b0e33937f3c..63a8666f0d774 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala
@@ -58,7 +58,7 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit
private var schema: StructType = _
// `ValueWriter`s for all fields of the schema
- private var rootFieldWriters: Seq[ValueWriter] = _
+ private var rootFieldWriters: Array[ValueWriter] = _
// The Parquet `RecordConsumer` to which all `InternalRow`s are written
private var recordConsumer: RecordConsumer = _
@@ -90,7 +90,7 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit
}
- this.rootFieldWriters = schema.map(_.dataType).map(makeWriter)
+ this.rootFieldWriters = schema.map(_.dataType).map(makeWriter).toArray[ValueWriter]
val messageType = new ParquetSchemaConverter(configuration).convert(schema)
val metadata = Map(ParquetReadSupport.SPARK_METADATA_KEY -> schemaString).asJava
@@ -116,7 +116,7 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit
}
private def writeFields(
- row: InternalRow, schema: StructType, fieldWriters: Seq[ValueWriter]): Unit = {
+ row: InternalRow, schema: StructType, fieldWriters: Array[ValueWriter]): Unit = {
var i = 0
while (i < row.numFields) {
if (!row.isNullAt(i)) {
@@ -192,7 +192,7 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit
makeDecimalWriter(precision, scale)
case t: StructType =>
- val fieldWriters = t.map(_.dataType).map(makeWriter)
+ val fieldWriters = t.map(_.dataType).map(makeWriter).toArray[ValueWriter]
(row: SpecializedGetters, ordinal: Int) =>
consumeGroup {
writeFields(row.getStruct(ordinal, t.length), t, fieldWriters)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index 7abf2ae5166b5..9647f2c0edccb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -22,7 +22,7 @@ import java.util.Locale
import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog._
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, RowOrdering}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, RowOrdering}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.command.DDLUtils
@@ -127,11 +127,11 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi
val resolver = sparkSession.sessionState.conf.resolver
val tableCols = existingTable.schema.map(_.name)
- // As we are inserting into an existing table, we should respect the existing schema and
- // adjust the column order of the given dataframe according to it, or throw exception
- // if the column names do not match.
+ // As we are inserting into an existing table, we should respect the existing schema, preserve
+ // the case and adjust the column order of the given DataFrame according to it, or throw
+ // an exception if the column names do not match.
val adjustedColumns = tableCols.map { col =>
- query.resolve(Seq(col), resolver).getOrElse {
+ query.resolve(Seq(col), resolver).map(Alias(_, col)()).getOrElse {
val inputColumns = query.schema.map(_.name).mkString(", ")
throw new AnalysisException(
s"cannot resolve '$col' given input columns: [$inputColumns]")
@@ -168,15 +168,9 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi
""".stripMargin)
}
- val newQuery = if (adjustedColumns != query.output) {
- Project(adjustedColumns, query)
- } else {
- query
- }
-
c.copy(
tableDesc = existingTable,
- query = Some(newQuery))
+ query = Some(Project(adjustedColumns, query)))
// Here we normalize partition, bucket and sort column names, w.r.t. the case sensitivity
// config, and do various checks:
@@ -315,7 +309,7 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi
* table. It also does data type casting and field renaming, to make sure that the columns to be
* inserted have the correct data type and fields have the correct names.
*/
-case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] {
+case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport {
private def preprocess(
insert: InsertIntoTable,
tblName: String,
@@ -367,7 +361,7 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] {
// Renaming is needed for handling the following cases like
// 1) Column names/types do not match, e.g., INSERT INTO TABLE tab1 SELECT 1, 2
// 2) Target tables have column metadata
- Alias(Cast(actual, expected.dataType), expected.name)(
+ Alias(cast(actual, expected.dataType), expected.name)(
explicitMetadata = Option(expected.metadata))
}
}
@@ -382,7 +376,7 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case i @ InsertIntoTable(table, _, query, _, _) if table.resolved && query.resolved =>
table match {
- case relation: CatalogRelation =>
+ case relation: HiveTableRelation =>
val metadata = relation.tableMeta
preprocess(i, metadata.identifier.quotedString, metadata.partitionColumnNames)
case LogicalRelation(h: HadoopFsRelation, _, catalogTable) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala
index d993ea6c6cef9..4b52f3e4c49b0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala
@@ -23,7 +23,8 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression, SortOrder}
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.internal.SQLConf
@@ -58,6 +59,24 @@ case class ReusedExchangeExec(override val output: Seq[Attribute], child: Exchan
override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
child.executeBroadcast()
}
+
+ // `ReusedExchangeExec` can have a set of output attribute ids distinct from its child's, so we
+ // need to update the attribute ids in `outputPartitioning` and `outputOrdering`.
+ private lazy val updateAttr: Expression => Expression = {
+ val originalAttrToNewAttr = AttributeMap(child.output.zip(output))
+ e => e.transform {
+ case attr: Attribute => originalAttrToNewAttr.getOrElse(attr, attr)
+ }
+ }
+
+ override def outputPartitioning: Partitioning = child.outputPartitioning match {
+ case h: HashPartitioning => h.copy(expressions = h.expressions.map(updateAttr))
+ case other => other
+ }
+
+ override def outputOrdering: Seq[SortOrder] = {
+ child.outputOrdering.map(updateAttr(_).asInstanceOf[SortOrder])
+ }
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
index f06544ea8ed04..eebe6ad2e7944 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
@@ -40,6 +40,9 @@ case class ShuffleExchange(
child: SparkPlan,
@transient coordinator: Option[ExchangeCoordinator]) extends Exchange {
+ // NOTE: coordinator can be null after serialization/deserialization,
+ // e.g. it can be null on the Executor side
+
override lazy val metrics = Map(
"dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"))
@@ -47,7 +50,7 @@ case class ShuffleExchange(
val extraInfo = coordinator match {
case Some(exchangeCoordinator) =>
s"(coordinator id: ${System.identityHashCode(exchangeCoordinator)})"
- case None => ""
+ case _ => ""
}
val simpleNodeName = "Exchange"
@@ -70,7 +73,7 @@ case class ShuffleExchange(
// the plan.
coordinator match {
case Some(exchangeCoordinator) => exchangeCoordinator.registerExchange(this)
- case None =>
+ case _ =>
}
}
@@ -117,7 +120,7 @@ case class ShuffleExchange(
val shuffleRDD = exchangeCoordinator.postShuffleRDD(this)
assert(shuffleRDD.partitions.length == newPartitioning.numPartitions)
shuffleRDD
- case None =>
+ case _ =>
val shuffleDependency = prepareShuffleDependency()
preparePostShuffleRDD(shuffleDependency)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index 0bc261d593df4..69715ab1f675f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -257,8 +257,8 @@ case class BroadcastHashJoinExec(
s"""
|boolean $conditionPassed = true;
|${eval.trim}
- |${ev.code}
|if ($matched != null) {
+ | ${ev.code}
| $conditionPassed = !${ev.isNull} && ${ev.value};
|}
""".stripMargin
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
index f380986951317..4d261dd422bc5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
@@ -35,11 +35,12 @@ class UnsafeCartesianRDD(
left : RDD[UnsafeRow],
right : RDD[UnsafeRow],
numFieldsOfRight: Int,
+ inMemoryBufferThreshold: Int,
spillThreshold: Int)
extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) {
override def compute(split: Partition, context: TaskContext): Iterator[(UnsafeRow, UnsafeRow)] = {
- val rowArray = new ExternalAppendOnlyUnsafeRowArray(spillThreshold)
+ val rowArray = new ExternalAppendOnlyUnsafeRowArray(inMemoryBufferThreshold, spillThreshold)
val partition = split.asInstanceOf[CartesianPartition]
rdd2.iterator(partition.s2, context).foreach(rowArray.add)
@@ -71,9 +72,12 @@ case class CartesianProductExec(
val leftResults = left.execute().asInstanceOf[RDD[UnsafeRow]]
val rightResults = right.execute().asInstanceOf[RDD[UnsafeRow]]
- val spillThreshold = sqlContext.conf.cartesianProductExecBufferSpillThreshold
-
- val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size, spillThreshold)
+ val pair = new UnsafeCartesianRDD(
+ leftResults,
+ rightResults,
+ right.output.size,
+ sqlContext.conf.cartesianProductExecBufferInMemoryThreshold,
+ sqlContext.conf.cartesianProductExecBufferSpillThreshold)
pair.mapPartitionsWithIndexInternal { (index, iter) =>
val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
val filtered = if (condition.isDefined) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 2dd1dc3da96c9..78190bf117281 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -533,7 +533,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
def append(key: Long, row: UnsafeRow): Unit = {
val sizeInBytes = row.getSizeInBytes
if (sizeInBytes >= (1 << SIZE_BITS)) {
- sys.error("Does not support row that is larger than 256M")
+ throw new UnsupportedOperationException("Does not support row that is larger than 256M")
}
if (key < minKey) {
@@ -543,19 +543,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
maxKey = key
}
- // There is 8 bytes for the pointer to next value
- if (cursor + 8 + row.getSizeInBytes > page.length * 8L + Platform.LONG_ARRAY_OFFSET) {
- val used = page.length
- if (used >= (1 << 30)) {
- sys.error("Can not build a HashedRelation that is larger than 8G")
- }
- ensureAcquireMemory(used * 8L * 2)
- val newPage = new Array[Long](used * 2)
- Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, Platform.LONG_ARRAY_OFFSET,
- cursor - Platform.LONG_ARRAY_OFFSET)
- page = newPage
- freeMemory(used * 8L)
- }
+ grow(row.getSizeInBytes)
// copy the bytes of UnsafeRow
val offset = cursor
@@ -588,7 +576,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
growArray()
} else if (numKeys > array.length / 2 * 0.75) {
// The fill ratio should be less than 0.75
- sys.error("Cannot build HashedRelation with more than 1/3 billions unique keys")
+ throw new UnsupportedOperationException(
+ "Cannot build HashedRelation with more than 1/3 billions unique keys")
}
}
} else {
@@ -599,6 +588,25 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
}
}
+ private def grow(inputRowSize: Int): Unit = {
+ // There is 8 bytes for the pointer to next value
+ val neededNumWords = (cursor - Platform.LONG_ARRAY_OFFSET + 8 + inputRowSize + 7) / 8
+ if (neededNumWords > page.length) {
+ if (neededNumWords > (1 << 30)) {
+ throw new UnsupportedOperationException(
+ "Can not build a HashedRelation that is larger than 8G")
+ }
+ val newNumWords = math.max(neededNumWords, math.min(page.length * 2, 1 << 30))
+ ensureAcquireMemory(newNumWords * 8L)
+ val newPage = new Array[Long](newNumWords.toInt)
+ Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, Platform.LONG_ARRAY_OFFSET,
+ cursor - Platform.LONG_ARRAY_OFFSET)
+ val used = page.length
+ page = newPage
+ freeMemory(used * 8L)
+ }
+ }
+
private def growArray(): Unit = {
var old_array = array
val n = array.length
@@ -733,6 +741,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
array = readLongArray(readBuffer, length)
val pageLength = readLong().toInt
page = readLongArray(readBuffer, pageLength)
+ // Restore the cursor variable so that this map can be serialized again on executors.
+ cursor = pageLength * 8 + Platform.LONG_ARRAY_OFFSET
}
override def readExternal(in: ObjectInput): Unit = {
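
The page is an Array[Long], so grow works in 8-byte words. A quick check of the rounding in neededNumWords (cursorOffset stands for cursor - Platform.LONG_ARRAY_OFFSET; the 8 extra bytes hold the pointer to the next value):

    def neededNumWords(cursorOffset: Long, inputRowSize: Int): Long =
      (cursorOffset + 8 + inputRowSize + 7) / 8

    assert(neededNumWords(0, 1) == 2)   // 9 bytes of payload round up to 2 words
    assert(neededNumWords(0, 16) == 3)  // 24 bytes fit exactly in 3 words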
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index c6aae1a4db2e4..70dada8b63ae9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -82,7 +82,7 @@ case class SortMergeJoinExec(
override def outputOrdering: Seq[SortOrder] = joinType match {
// For inner join, orders of both sides keys should be kept.
- case Inner =>
+ case _: InnerLike =>
val leftKeyOrdering = getKeyOrdering(leftKeys, left.outputOrdering)
val rightKeyOrdering = getKeyOrdering(rightKeys, right.outputOrdering)
leftKeyOrdering.zip(rightKeyOrdering).map { case (lKey, rKey) =>
@@ -130,9 +130,14 @@ case class SortMergeJoinExec(
sqlContext.conf.sortMergeJoinExecBufferSpillThreshold
}
+ private def getInMemoryThreshold: Int = {
+ sqlContext.conf.sortMergeJoinExecBufferInMemoryThreshold
+ }
+
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
val spillThreshold = getSpillThreshold
+ val inMemoryThreshold = getInMemoryThreshold
left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
val boundCondition: (InternalRow) => Boolean = {
condition.map { cond =>
@@ -158,6 +163,7 @@ case class SortMergeJoinExec(
keyOrdering,
RowIterator.fromScala(leftIter),
RowIterator.fromScala(rightIter),
+ inMemoryThreshold,
spillThreshold
)
private[this] val joinRow = new JoinedRow
@@ -201,6 +207,7 @@ case class SortMergeJoinExec(
keyOrdering,
streamedIter = RowIterator.fromScala(leftIter),
bufferedIter = RowIterator.fromScala(rightIter),
+ inMemoryThreshold,
spillThreshold
)
val rightNullRow = new GenericInternalRow(right.output.length)
@@ -214,6 +221,7 @@ case class SortMergeJoinExec(
keyOrdering,
streamedIter = RowIterator.fromScala(rightIter),
bufferedIter = RowIterator.fromScala(leftIter),
+ inMemoryThreshold,
spillThreshold
)
val leftNullRow = new GenericInternalRow(left.output.length)
@@ -247,6 +255,7 @@ case class SortMergeJoinExec(
keyOrdering,
RowIterator.fromScala(leftIter),
RowIterator.fromScala(rightIter),
+ inMemoryThreshold,
spillThreshold
)
private[this] val joinRow = new JoinedRow
@@ -281,6 +290,7 @@ case class SortMergeJoinExec(
keyOrdering,
RowIterator.fromScala(leftIter),
RowIterator.fromScala(rightIter),
+ inMemoryThreshold,
spillThreshold
)
private[this] val joinRow = new JoinedRow
@@ -290,6 +300,7 @@ case class SortMergeJoinExec(
currentLeftRow = smjScanner.getStreamedRow
val currentRightMatches = smjScanner.getBufferedMatches
if (currentRightMatches == null || currentRightMatches.length == 0) {
+ numOutputRows += 1
return true
}
var found = false
@@ -321,6 +332,7 @@ case class SortMergeJoinExec(
keyOrdering,
RowIterator.fromScala(leftIter),
RowIterator.fromScala(rightIter),
+ inMemoryThreshold,
spillThreshold
)
private[this] val joinRow = new JoinedRow
@@ -371,6 +383,7 @@ case class SortMergeJoinExec(
keys: Seq[Expression],
input: Seq[Attribute]): Seq[ExprCode] = {
ctx.INPUT_ROW = row
+ ctx.currentVars = null
keys.map(BindReferences.bindReference(_, input).genCode(ctx))
}
@@ -418,8 +431,10 @@ case class SortMergeJoinExec(
val clsName = classOf[ExternalAppendOnlyUnsafeRowArray].getName
val spillThreshold = getSpillThreshold
+ val inMemoryThreshold = getInMemoryThreshold
- ctx.addMutableState(clsName, matches, s"$matches = new $clsName($spillThreshold);")
+ ctx.addMutableState(clsName, matches,
+ s"$matches = new $clsName($inMemoryThreshold, $spillThreshold);")
// Copy the left keys as class members so they could be used in next function call.
val matchedKeyVars = copyKeys(ctx, leftKeyVars)
@@ -624,6 +639,9 @@ case class SortMergeJoinExec(
* @param streamedIter an input whose rows will be streamed.
* @param bufferedIter an input whose rows will be buffered to construct sequences of rows that
* have the same join key.
+ * @param inMemoryThreshold Threshold for number of rows guaranteed to be held in memory by
+ * internal buffer
+ * @param spillThreshold Threshold for number of rows to be spilled by internal buffer
*/
private[joins] class SortMergeJoinScanner(
streamedKeyGenerator: Projection,
@@ -631,7 +649,8 @@ private[joins] class SortMergeJoinScanner(
keyOrdering: Ordering[InternalRow],
streamedIter: RowIterator,
bufferedIter: RowIterator,
- bufferThreshold: Int) {
+ inMemoryThreshold: Int,
+ spillThreshold: Int) {
private[this] var streamedRow: InternalRow = _
private[this] var streamedRowKey: InternalRow = _
private[this] var bufferedRow: InternalRow = _
@@ -642,7 +661,8 @@ private[joins] class SortMergeJoinScanner(
*/
private[this] var matchJoinKey: InternalRow = _
/** Buffered rows from the buffered side of the join. This is empty if there are no matches. */
- private[this] val bufferedMatches = new ExternalAppendOnlyUnsafeRowArray(bufferThreshold)
+ private[this] val bufferedMatches =
+ new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold)
// Initialization (note: do _not_ want to advance streamed here).
advancedBufferedToRowWithNullFreeJoinKey()
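A note on the `case _: InnerLike` change at the top of this file's hunks: InnerLike is the common parent of the Inner and Cross join types in org.apache.spark.sql.catalyst.plans, so the output-ordering propagation now also covers explicit cross joins. A tiny illustration, not part of the patch:

import org.apache.spark.sql.catalyst.plans.{Cross, Inner, InnerLike}

// Both plain inner joins and explicit cross joins match `case _: InnerLike`.
Seq(Inner, Cross).foreach(joinType => assert(joinType.isInstanceOf[InnerLike]))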
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index 48c7b80bffe03..3643ef3497cfe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState
import org.apache.spark.sql.execution.streaming.GroupStateImpl
+import org.apache.spark.sql.streaming.GroupStateTimeout
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -361,8 +362,11 @@ object MapGroupsExec {
groupingAttributes: Seq[Attribute],
dataAttributes: Seq[Attribute],
outputObjAttr: Attribute,
+ timeoutConf: GroupStateTimeout,
child: SparkPlan): MapGroupsExec = {
- val f = (key: Any, values: Iterator[Any]) => func(key, values, new GroupStateImpl[Any](None))
+ val f = (key: Any, values: Iterator[Any]) => {
+ func(key, values, GroupStateImpl.createForBatch(timeoutConf))
+ }
new MapGroupsExec(f, keyDeserializer, valueDeserializer,
groupingAttributes, dataAttributes, outputObjAttr, child)
}
@@ -393,7 +397,11 @@ case class FlatMapGroupsInRExec(
override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
override def requiredChildDistribution: Seq[Distribution] =
- ClusteredDistribution(groupingAttributes) :: Nil
+ if (groupingAttributes.isEmpty) {
+ AllTuples :: Nil
+ } else {
+ ClusteredDistribution(groupingAttributes) :: Nil
+ }
override def requiredChildOrdering: Seq[Seq[SortOrder]] =
Seq(groupingAttributes.map(SortOrder(_, Ascending)))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala
index 408c8f81f17ba..77bc0ba5548dd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala
@@ -169,13 +169,15 @@ abstract class CompactibleFileStreamLog[T <: AnyRef : ClassTag](
*/
private def compact(batchId: Long, logs: Array[T]): Boolean = {
val validBatches = getValidBatchesBeforeCompactionBatch(batchId, compactInterval)
- val allLogs = validBatches.flatMap(batchId => super.get(batchId)).flatten ++ logs
- if (super.add(batchId, compactLogs(allLogs).toArray)) {
- true
- } else {
- // Return false as there is another writer.
- false
- }
+ val allLogs = validBatches.map { id =>
+ super.get(id).getOrElse {
+ throw new IllegalStateException(
+ s"${batchIdToPath(id)} doesn't exist when compacting batch $batchId " +
+ s"(compactInterval: $compactInterval)")
+ }
+ }.flatten ++ logs
+ // If another writer has already written this batch, super.add returns false.
+ super.add(batchId, compactLogs(allLogs).toArray)
}
/**
@@ -190,7 +192,13 @@ abstract class CompactibleFileStreamLog[T <: AnyRef : ClassTag](
if (latestId >= 0) {
try {
val logs =
- getAllValidBatches(latestId, compactInterval).flatMap(id => super.get(id)).flatten
+ getAllValidBatches(latestId, compactInterval).map { id =>
+ super.get(id).getOrElse {
+ throw new IllegalStateException(
+ s"${batchIdToPath(id)} doesn't exist " +
+ s"(latestId: $latestId, compactInterval: $compactInterval)")
+ }
+ }.flatten
return compactLogs(logs).toArray
} catch {
case e: IOException =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala
index 25cf609fc336e..55e7508b2ed29 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala
@@ -27,27 +27,25 @@ import org.apache.spark.unsafe.types.CalendarInterval
import org.apache.spark.util.AccumulatorV2
/** Class for collecting event time stats with an accumulator */
-case class EventTimeStats(var max: Long, var min: Long, var sum: Long, var count: Long) {
+case class EventTimeStats(var max: Long, var min: Long, var avg: Double, var count: Long) {
def add(eventTime: Long): Unit = {
this.max = math.max(this.max, eventTime)
this.min = math.min(this.min, eventTime)
- this.sum += eventTime
this.count += 1
+ this.avg += (eventTime - avg) / count
}
def merge(that: EventTimeStats): Unit = {
this.max = math.max(this.max, that.max)
this.min = math.min(this.min, that.min)
- this.sum += that.sum
this.count += that.count
+ this.avg += (that.avg - this.avg) * that.count / this.count
}
-
- def avg: Long = sum / count
}
object EventTimeStats {
def zero: EventTimeStats = EventTimeStats(
- max = Long.MinValue, min = Long.MaxValue, sum = 0L, count = 0L)
+ max = Long.MinValue, min = Long.MaxValue, avg = 0.0, count = 0L)
}
/** Accumulator that collects stats on event time in a batch. */
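For context on the switch from a running sum to a running average above: the incremental update and the count-weighted merge reproduce the old sum/count result without the overflow risk of accumulating a long sum. A quick standalone check, not Spark code:

// Average of 10, 20, 30 should be 20 no matter how the two partial aggregates are combined.
var (avgA, cntA) = (0.0, 0L)
for (x <- Seq(10L, 20L)) { cntA += 1; avgA += (x - avgA) / cntA }   // avgA = 15.0
var (avgB, cntB) = (0.0, 0L)
for (x <- Seq(30L)) { cntB += 1; avgB += (x - avgB) / cntB }        // avgB = 30.0

cntA += cntB                          // merge counts first, as EventTimeStats.merge does
avgA += (avgB - avgA) * cntB / cntA   // count-weighted mean
assert(avgA == 20.0 && cntA == 3L)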
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
index 07ec4e9429e42..397be27337e42 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
@@ -42,9 +42,11 @@ object FileStreamSink extends Logging {
try {
val hdfsPath = new Path(singlePath)
val fs = hdfsPath.getFileSystem(hadoopConf)
- val metadataPath = new Path(hdfsPath, metadataDir)
- val res = fs.exists(metadataPath)
- res
+ if (fs.isDirectory(hdfsPath)) {
+ fs.exists(new Path(hdfsPath, metadataDir))
+ } else {
+ false
+ }
} catch {
case NonFatal(e) =>
logWarning(s"Error while looking for metadata directory.")
@@ -53,6 +55,26 @@ object FileStreamSink extends Logging {
case _ => false
}
}
+
+ /**
+ * Returns true if the path is the metadata dir or its ancestor is the metadata dir.
+ * E.g.:
+ * - ancestorIsMetadataDirectory(/.../_spark_metadata) => true
+ * - ancestorIsMetadataDirectory(/.../_spark_metadata/0) => true
+ * - ancestorIsMetadataDirectory(/a/b/c) => false
+ */
+ def ancestorIsMetadataDirectory(path: Path, hadoopConf: Configuration): Boolean = {
+ val fs = path.getFileSystem(hadoopConf)
+ var currentPath = path.makeQualified(fs.getUri, fs.getWorkingDirectory)
+ while (currentPath != null) {
+ if (currentPath.getName == FileStreamSink.metadataDir) {
+ return true
+ } else {
+ currentPath = currentPath.getParent
+ }
+ }
+ return false
+ }
}
/**
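A small usage sketch for the new ancestorIsMetadataDirectory helper (paths are illustrative; the method qualifies the path against the file system resolved from the given Hadoop configuration):

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.execution.streaming.FileStreamSink

val hadoopConf = new Configuration()
FileStreamSink.ancestorIsMetadataDirectory(new Path("/out/_spark_metadata"), hadoopConf)    // true
FileStreamSink.ancestorIsMetadataDirectory(new Path("/out/_spark_metadata/0"), hadoopConf)  // true
FileStreamSink.ancestorIsMetadataDirectory(new Path("/out/year=2017/part-0"), hadoopConf)   // false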
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala
index 8d718b2164d22..c9939ac1db746 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution.streaming
+import java.net.URI
+
import org.apache.hadoop.fs.{FileStatus, Path}
import org.json4s.NoTypeHints
import org.json4s.jackson.Serialization
@@ -47,7 +49,8 @@ case class SinkFileStatus(
action: String) {
def toFileStatus: FileStatus = {
- new FileStatus(size, isDir, blockReplication, blockSize, modificationTime, new Path(path))
+ new FileStatus(
+ size, isDir, blockReplication, blockSize, modificationTime, new Path(new URI(path)))
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
index a9e64c640042a..509bf4840f433 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
@@ -47,8 +47,9 @@ class FileStreamSource(
private val hadoopConf = sparkSession.sessionState.newHadoopConf()
+ @transient private val fs = new Path(path).getFileSystem(hadoopConf)
+
private val qualifiedBasePath: Path = {
- val fs = new Path(path).getFileSystem(hadoopConf)
fs.makeQualified(new Path(path)) // can contains glob patterns
}
@@ -187,7 +188,7 @@ class FileStreamSource(
if (SparkHadoopUtil.get.isGlobPath(new Path(path))) Some(false) else None
private def allFilesUsingInMemoryFileIndex() = {
- val globbedPaths = SparkHadoopUtil.get.globPathIfNecessary(qualifiedBasePath)
+ val globbedPaths = SparkHadoopUtil.get.globPathIfNecessary(fs, qualifiedBasePath)
val fileIndex = new InMemoryFileIndex(sparkSession, globbedPaths, options, Some(new StructType))
fileIndex.allFiles()
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala
index 33e6a1d5d6e18..8628471fdb925 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala
@@ -115,7 +115,10 @@ class FileStreamSourceLog(
Map.empty[Long, Option[Array[FileEntry]]]
}
- (existedBatches ++ retrievedBatches).map(i => i._1 -> i._2.get).toArray.sortBy(_._1)
+ val batches =
+ (existedBatches ++ retrievedBatches).map(i => i._1 -> i._2.get).toArray.sortBy(_._1)
+ HDFSMetadataLog.verifyBatchIds(batches.map(_._1), startId, endId)
+ batches
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
index e42df5dd61c70..3ceb4cf84a413 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
@@ -120,7 +120,7 @@ case class FlatMapGroupsWithStateExec(
val filteredIter = watermarkPredicateForData match {
case Some(predicate) if timeoutConf == EventTimeTimeout =>
iter.filter(row => !predicate.eval(row))
- case None =>
+ case _ =>
iter
}
@@ -215,7 +215,7 @@ case class FlatMapGroupsWithStateExec(
val keyObj = getKeyObj(keyRow) // convert key to objects
val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects
val stateObjOption = getStateObj(prevStateRowOption)
- val keyedState = new GroupStateImpl(
+ val keyedState = GroupStateImpl.createForStreaming(
stateObjOption,
batchTimestampMs.getOrElse(NO_TIMESTAMP),
eventTimeWatermark.getOrElse(NO_TIMESTAMP),
@@ -230,6 +230,20 @@ case class FlatMapGroupsWithStateExec(
// When the iterator is consumed, then write changes to state
def onIteratorCompletion: Unit = {
+
+ val currentTimeoutTimestamp = keyedState.getTimeoutTimestamp
+ // If the state has not yet been set but timeout has been set, then
+ // we have to generate a row to save the timeout. However, attempting to serialize
+ // null using the case class encoder throws -
+ // java.lang.NullPointerException: Null value appeared in non-nullable field:
+ // If the schema is inferred from a Scala tuple / case class, or a Java bean, please
+ // try to use scala.Option[_] or other nullable types.
+ if (!keyedState.exists && currentTimeoutTimestamp != NO_TIMESTAMP) {
+ throw new IllegalStateException(
+ "Cannot set timeout when state is not defined, that is, state has not been" +
+ "initialized or has been removed")
+ }
+
if (keyedState.hasRemoved) {
store.remove(keyRow)
numUpdatedStateRows += 1
@@ -239,7 +253,6 @@ case class FlatMapGroupsWithStateExec(
case Some(row) => getTimeoutTimestamp(row)
case None => NO_TIMESTAMP
}
- val currentTimeoutTimestamp = keyedState.getTimeoutTimestamp
val stateRowToWrite = if (keyedState.hasUpdated) {
getStateRow(keyedState.get)
} else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala
index 148d92247d6f0..4401e86936af9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala
@@ -38,20 +38,13 @@ import org.apache.spark.unsafe.types.CalendarInterval
* @param hasTimedOut Whether the key for which this state wrapped is being created is
* getting timed out or not.
*/
-private[sql] class GroupStateImpl[S](
+private[sql] class GroupStateImpl[S] private(
optionalValue: Option[S],
batchProcessingTimeMs: Long,
eventTimeWatermarkMs: Long,
timeoutConf: GroupStateTimeout,
override val hasTimedOut: Boolean) extends GroupState[S] {
- // Constructor to create dummy state when using mapGroupsWithState in a batch query
- def this(optionalValue: Option[S]) = this(
- optionalValue,
- batchProcessingTimeMs = NO_TIMESTAMP,
- eventTimeWatermarkMs = NO_TIMESTAMP,
- timeoutConf = GroupStateTimeout.NoTimeout,
- hasTimedOut = false)
private var value: S = optionalValue.getOrElse(null.asInstanceOf[S])
private var defined: Boolean = optionalValue.isDefined
private var updated: Boolean = false // whether value has been updated (but not removed)
@@ -91,7 +84,6 @@ private[sql] class GroupStateImpl[S](
defined = false
updated = false
removed = true
- timeoutTimestamp = NO_TIMESTAMP
}
override def setTimeoutDuration(durationMs: Long): Unit = {
@@ -100,21 +92,10 @@ private[sql] class GroupStateImpl[S](
"Cannot set timeout duration without enabling processing time timeout in " +
"map/flatMapGroupsWithState")
}
- if (!defined) {
- throw new IllegalStateException(
- "Cannot set timeout information without any state value, " +
- "state has either not been initialized, or has already been removed")
- }
-
if (durationMs <= 0) {
throw new IllegalArgumentException("Timeout duration must be positive")
}
- if (!removed && batchProcessingTimeMs != NO_TIMESTAMP) {
- timeoutTimestamp = durationMs + batchProcessingTimeMs
- } else {
- // This is being called in a batch query, hence no processing timestamp.
- // Just ignore any attempts to set timeout.
- }
+ timeoutTimestamp = durationMs + batchProcessingTimeMs
}
override def setTimeoutDuration(duration: String): Unit = {
@@ -135,12 +116,7 @@ private[sql] class GroupStateImpl[S](
s"Timeout timestamp ($timestampMs) cannot be earlier than the " +
s"current watermark ($eventTimeWatermarkMs)")
}
- if (!removed && batchProcessingTimeMs != NO_TIMESTAMP) {
- timeoutTimestamp = timestampMs
- } else {
- // This is being called in a batch query, hence no processing timestamp.
- // Just ignore any attempts to set timeout.
- }
+ timeoutTimestamp = timestampMs
}
@throws[IllegalArgumentException]("if 'additionalDuration' is invalid")
@@ -213,11 +189,6 @@ private[sql] class GroupStateImpl[S](
"Cannot set timeout timestamp without enabling event time timeout in " +
"map/flatMapGroupsWithState")
}
- if (!defined) {
- throw new IllegalStateException(
- "Cannot set timeout timestamp without any state value, " +
- "state has either not been initialized, or has already been removed")
- }
}
}
@@ -225,4 +196,23 @@ private[sql] class GroupStateImpl[S](
private[sql] object GroupStateImpl {
// Value used to represent the lack of a valid timestamp as a long
val NO_TIMESTAMP = -1L
+
+ def createForStreaming[S](
+ optionalValue: Option[S],
+ batchProcessingTimeMs: Long,
+ eventTimeWatermarkMs: Long,
+ timeoutConf: GroupStateTimeout,
+ hasTimedOut: Boolean): GroupStateImpl[S] = {
+ new GroupStateImpl[S](
+ optionalValue, batchProcessingTimeMs, eventTimeWatermarkMs, timeoutConf, hasTimedOut)
+ }
+
+ def createForBatch(timeoutConf: GroupStateTimeout): GroupStateImpl[Any] = {
+ new GroupStateImpl[Any](
+ optionalValue = None,
+ batchProcessingTimeMs = NO_TIMESTAMP,
+ eventTimeWatermarkMs = NO_TIMESTAMP,
+ timeoutConf,
+ hasTimedOut = false)
+ }
}
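The removed public constructor is now reached only through the two factories; for example, the batch path used by MapGroupsExec builds a state object with no timestamps. A sketch (GroupStateImpl is private[sql], so this only compiles from within the org.apache.spark.sql packages):

import org.apache.spark.sql.execution.streaming.GroupStateImpl
import org.apache.spark.sql.streaming.GroupStateTimeout

// Batch (non-streaming) mapGroupsWithState: empty state, no processing-time or event-time info.
val batchState = GroupStateImpl.createForBatch(GroupStateTimeout.NoTimeout)
assert(!batchState.exists && !batchState.hasTimedOut)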
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala
index 46bfc297931fb..5f8973fd09460 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala
@@ -123,7 +123,7 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path:
serialize(metadata, output)
return Some(tempPath)
} finally {
- IOUtils.closeQuietly(output)
+ output.close()
}
} catch {
case e: FileAlreadyExistsException =>
@@ -211,13 +211,17 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path:
}
override def get(startId: Option[Long], endId: Option[Long]): Array[(Long, T)] = {
+ assert(startId.isEmpty || endId.isEmpty || startId.get <= endId.get)
val files = fileManager.list(metadataPath, batchFilesFilter)
val batchIds = files
.map(f => pathToBatchId(f.getPath))
.filter { batchId =>
(endId.isEmpty || batchId <= endId.get) && (startId.isEmpty || batchId >= startId.get)
- }
- batchIds.sorted.map(batchId => (batchId, get(batchId))).filter(_._2.isDefined).map {
+ }.sorted
+
+ verifyBatchIds(batchIds, startId, endId)
+
+ batchIds.map(batchId => (batchId, get(batchId))).filter(_._2.isDefined).map {
case (batchId, metadataOption) =>
(batchId, metadataOption.get)
}
@@ -437,4 +441,51 @@ object HDFSMetadataLog {
}
}
}
+
+ /**
+ * Verify if batchIds are continuous and between `startId` and `endId`.
+ *
+ * @param batchIds the sorted ids to verify.
+ * @param startId the start id. If it's set, batchIds should start with this id.
+ * @param endId the end id. If it's set, batchIds should end with this id.
+ */
+ def verifyBatchIds(batchIds: Seq[Long], startId: Option[Long], endId: Option[Long]): Unit = {
+ // Verify that we can get all batches between `startId` and `endId`.
+ if (startId.isDefined || endId.isDefined) {
+ if (batchIds.isEmpty) {
+ throw new IllegalStateException(s"batch ${startId.orElse(endId).get} doesn't exist")
+ }
+ if (startId.isDefined) {
+ val minBatchId = batchIds.head
+ assert(minBatchId >= startId.get)
+ if (minBatchId != startId.get) {
+ val missingBatchIds = startId.get to minBatchId
+ throw new IllegalStateException(
+ s"batches (${missingBatchIds.mkString(", ")}) don't exist " +
+ s"(startId: $startId, endId: $endId)")
+ }
+ }
+
+ if (endId.isDefined) {
+ val maxBatchId = batchIds.last
+ assert(maxBatchId <= endId.get)
+ if (maxBatchId != endId.get) {
+ val missingBatchIds = maxBatchId to endId.get
+ throw new IllegalStateException(
+ s"batches (${missingBatchIds.mkString(", ")}) don't exist " +
+ s"(startId: $startId, endId: $endId)")
+ }
+ }
+ }
+
+ if (batchIds.nonEmpty) {
+ val minBatchId = batchIds.head
+ val maxBatchId = batchIds.last
+ val missingBatchIds = (minBatchId to maxBatchId).toSet -- batchIds
+ if (missingBatchIds.nonEmpty) {
+ throw new IllegalStateException(s"batches (${missingBatchIds.mkString(", ")}) " +
+ s"don't exist (startId: $startId, endId: $endId)")
+ }
+ }
+ }
}
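Behavior sketch for the new verifyBatchIds helper, based on the checks above (ids must be contiguous and must span startId..endId when those bounds are given):

import org.apache.spark.sql.execution.streaming.HDFSMetadataLog.verifyBatchIds

verifyBatchIds(Seq(2L, 3L, 4L), Some(2L), Some(4L)) // ok: contiguous and spans the bounds
verifyBatchIds(Seq(2L, 4L), Some(2L), Some(4L))     // IllegalStateException: batches (3) don't exist
verifyBatchIds(Seq(3L, 4L), Some(2L), Some(4L))     // IllegalStateException: batches (2, 3) don't exist
verifyBatchIds(Seq.empty, Some(0L), None)           // IllegalStateException: batch 0 doesn't exist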
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala
index 5551d12fa8ad2..66b11ecddf233 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala
@@ -17,15 +17,11 @@
package org.apache.spark.sql.execution.streaming
-import java.{util => ju}
-
-import scala.collection.mutable
-
import com.codahale.metrics.{Gauge, MetricRegistry}
import org.apache.spark.internal.Logging
import org.apache.spark.metrics.source.{Source => CodahaleSource}
-import org.apache.spark.util.Clock
+import org.apache.spark.sql.streaming.StreamingQueryProgress
/**
* Serves metrics from a [[org.apache.spark.sql.streaming.StreamingQuery]] to
@@ -39,14 +35,17 @@ class MetricsReporter(
// Metric names should not have . in them, so that all the metrics of a query are identified
// together in Ganglia as a single metric group
- registerGauge("inputRate-total", () => stream.lastProgress.inputRowsPerSecond)
- registerGauge("processingRate-total", () => stream.lastProgress.inputRowsPerSecond)
- registerGauge("latency", () => stream.lastProgress.durationMs.get("triggerExecution").longValue())
-
- private def registerGauge[T](name: String, f: () => T)(implicit num: Numeric[T]): Unit = {
+ registerGauge("inputRate-total", _.inputRowsPerSecond, 0.0)
+ registerGauge("processingRate-total", _.processedRowsPerSecond, 0.0)
+ registerGauge("latency", _.durationMs.get("triggerExecution").longValue(), 0L)
+
+ private def registerGauge[T](
+ name: String,
+ f: StreamingQueryProgress => T,
+ default: T): Unit = {
synchronized {
metricRegistry.register(name, new Gauge[T] {
- override def getValue: T = f()
+ override def getValue: T = Option(stream.lastProgress).map(f).getOrElse(default)
})
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
index 693933f95a231..db46fcd9dfe78 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.streaming
import java.text.SimpleDateFormat
-import java.util.{Date, TimeZone, UUID}
+import java.util.{Date, UUID}
import scala.collection.mutable
import scala.collection.JavaConverters._
@@ -26,6 +26,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan}
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.streaming._
import org.apache.spark.sql.streaming.StreamingQueryListener.QueryProgressEvent
@@ -82,7 +83,7 @@ trait ProgressReporter extends Logging {
private var lastNoDataProgressEventTime = Long.MinValue
private val timestampFormat = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'") // ISO8601
- timestampFormat.setTimeZone(TimeZone.getTimeZone("UTC"))
+ timestampFormat.setTimeZone(DateTimeUtils.getTimeZone("UTC"))
@volatile
protected var currentStatus: StreamingQueryStatus = {
@@ -266,7 +267,7 @@ trait ProgressReporter extends Logging {
Map(
"max" -> stats.max,
"min" -> stats.min,
- "avg" -> stats.avg).mapValues(formatTimestamp)
+ "avg" -> stats.avg.toLong).mapValues(formatTimestamp)
}.headOption.getOrElse(Map.empty) ++ watermarkTimestamp
ExecutionStats(numInputRows, stateOperators, eventTimeStats)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala
new file mode 100644
index 0000000000000..e61a8eb628891
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala
@@ -0,0 +1,243 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io._
+import java.nio.charset.StandardCharsets
+import java.util.concurrent.TimeUnit
+
+import org.apache.commons.io.IOUtils
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.util.JavaUtils
+import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
+import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider}
+import org.apache.spark.sql.types._
+import org.apache.spark.util.{ManualClock, SystemClock}
+
+/**
+ * A source that generates incrementing long values with timestamps. Each generated row has two
+ * columns: a timestamp column for the generated time and an auto-incrementing long column
+ * starting with 0L.
+ *
+ * This source supports the following options:
+ * - `rowsPerSecond` (e.g. 100, default: 1): How many rows should be generated per second.
+ * - `rampUpTime` (e.g. 5s, default: 0s): How long to ramp up before the generating speed
+ * becomes `rowsPerSecond`. Using finer granularities than seconds will be truncated to integer
+ * seconds.
+ * - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the
+ * generated rows. The source will try its best to reach `rowsPerSecond`, but the query may
+ * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed.
+ */
+class RateSourceProvider extends StreamSourceProvider with DataSourceRegister {
+
+ override def sourceSchema(
+ sqlContext: SQLContext,
+ schema: Option[StructType],
+ providerName: String,
+ parameters: Map[String, String]): (String, StructType) =
+ (shortName(), RateSourceProvider.SCHEMA)
+
+ override def createSource(
+ sqlContext: SQLContext,
+ metadataPath: String,
+ schema: Option[StructType],
+ providerName: String,
+ parameters: Map[String, String]): Source = {
+ val params = CaseInsensitiveMap(parameters)
+
+ val rowsPerSecond = params.get("rowsPerSecond").map(_.toLong).getOrElse(1L)
+ if (rowsPerSecond <= 0) {
+ throw new IllegalArgumentException(
+ s"Invalid value '${params("rowsPerSecond")}'. The option 'rowsPerSecond' " +
+ "must be positive")
+ }
+
+ val rampUpTimeSeconds =
+ params.get("rampUpTime").map(JavaUtils.timeStringAsSec(_)).getOrElse(0L)
+ if (rampUpTimeSeconds < 0) {
+ throw new IllegalArgumentException(
+ s"Invalid value '${params("rampUpTime")}'. The option 'rampUpTime' " +
+ "must not be negative")
+ }
+
+ val numPartitions = params.get("numPartitions").map(_.toInt).getOrElse(
+ sqlContext.sparkContext.defaultParallelism)
+ if (numPartitions <= 0) {
+ throw new IllegalArgumentException(
+ s"Invalid value '${params("numPartitions")}'. The option 'numPartitions' " +
+ "must be positive")
+ }
+
+ new RateStreamSource(
+ sqlContext,
+ metadataPath,
+ rowsPerSecond,
+ rampUpTimeSeconds,
+ numPartitions,
+ params.get("useManualClock").map(_.toBoolean).getOrElse(false) // Only for testing
+ )
+ }
+ override def shortName(): String = "rate"
+}
+
+object RateSourceProvider {
+ val SCHEMA =
+ StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil)
+
+ val VERSION = 1
+}
+
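A hedged usage sketch for the new source (assumes a SparkSession named spark and that the provider is registered under its short name "rate"; the option values are illustrative):

val rateDF = spark.readStream
  .format("rate")
  .option("rowsPerSecond", "100")
  .option("rampUpTime", "5s")
  .option("numPartitions", "4")
  .load()
// rateDF has the schema declared above: timestamp: TimestampType, value: LongType.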
+class RateStreamSource(
+ sqlContext: SQLContext,
+ metadataPath: String,
+ rowsPerSecond: Long,
+ rampUpTimeSeconds: Long,
+ numPartitions: Int,
+ useManualClock: Boolean) extends Source with Logging {
+
+ import RateSourceProvider._
+ import RateStreamSource._
+
+ val clock = if (useManualClock) new ManualClock else new SystemClock
+
+ private val maxSeconds = Long.MaxValue / rowsPerSecond
+
+ if (rampUpTimeSeconds > maxSeconds) {
+ throw new ArithmeticException(
+ s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" +
+ s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.")
+ }
+
+ private val startTimeMs = {
+ val metadataLog =
+ new HDFSMetadataLog[LongOffset](sqlContext.sparkSession, metadataPath) {
+ override def serialize(metadata: LongOffset, out: OutputStream): Unit = {
+ val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8))
+ writer.write("v" + VERSION + "\n")
+ writer.write(metadata.json)
+ writer.flush
+ }
+
+ override def deserialize(in: InputStream): LongOffset = {
+ val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8))
+ // HDFSMetadataLog guarantees that it never creates a partial file.
+ assert(content.length != 0)
+ if (content(0) == 'v') {
+ val indexOfNewLine = content.indexOf("\n")
+ if (indexOfNewLine > 0) {
+ val version = parseVersion(content.substring(0, indexOfNewLine), VERSION)
+ LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1)))
+ } else {
+ throw new IllegalStateException(
+ s"Log file was malformed: failed to detect the log file version line.")
+ }
+ } else {
+ throw new IllegalStateException(
+ s"Log file was malformed: failed to detect the log file version line.")
+ }
+ }
+ }
+
+ metadataLog.get(0).getOrElse {
+ val offset = LongOffset(clock.getTimeMillis())
+ metadataLog.add(0, offset)
+ logInfo(s"Start time: $offset")
+ offset
+ }.offset
+ }
+
+ /** When the system time runs backward, "lastTimeMs" will make sure we are still monotonic. */
+ @volatile private var lastTimeMs = startTimeMs
+
+ override def schema: StructType = RateSourceProvider.SCHEMA
+
+ override def getOffset: Option[Offset] = {
+ val now = clock.getTimeMillis()
+ if (lastTimeMs < now) {
+ lastTimeMs = now
+ }
+ Some(LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - startTimeMs)))
+ }
+
+ override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
+ val startSeconds = start.flatMap(LongOffset.convert(_).map(_.offset)).getOrElse(0L)
+ val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L)
+ assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)")
+ if (endSeconds > maxSeconds) {
+ throw new ArithmeticException("Integer overflow. Max offset with " +
+ s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.")
+ }
+ // Fix "lastTimeMs" for recovery
+ if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs) {
+ lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs
+ }
+ val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds)
+ val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds)
+ logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " +
+ s"rangeStart: $rangeStart, rangeEnd: $rangeEnd")
+
+ if (rangeStart == rangeEnd) {
+ return sqlContext.internalCreateDataFrame(sqlContext.sparkContext.emptyRDD, schema)
+ }
+
+ val localStartTimeMs = startTimeMs + TimeUnit.SECONDS.toMillis(startSeconds)
+ val relativeMsPerValue =
+ TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart)
+
+ val rdd = sqlContext.sparkContext.range(rangeStart, rangeEnd, 1, numPartitions).map { v =>
+ val relative = math.round((v - rangeStart) * relativeMsPerValue)
+ InternalRow(DateTimeUtils.fromMillis(relative + localStartTimeMs), v)
+ }
+ sqlContext.internalCreateDataFrame(rdd, schema)
+ }
+
+ override def stop(): Unit = {}
+
+ override def toString: String = s"RateSource[rowsPerSecond=$rowsPerSecond, " +
+ s"rampUpTimeSeconds=$rampUpTimeSeconds, numPartitions=$numPartitions]"
+}
+
+object RateStreamSource {
+
+ /** Calculate the end value we will emit at the time `seconds`. */
+ def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = {
+ // E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10
+ // Then speedDeltaPerSecond = 2
+ //
+ // seconds = 0 1 2 3 4 5 6
+ // speed = 0 2 4 6 8 10 10 (speedDeltaPerSecond * seconds)
+ // end value = 0 2 6 12 20 30 40 (0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2
+ val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1)
+ if (seconds <= rampUpTimeSeconds) {
+ // Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2" in a special way to
+ // avoid overflow
+ if (seconds % 2 == 1) {
+ (seconds + 1) / 2 * speedDeltaPerSecond * seconds
+ } else {
+ seconds / 2 * speedDeltaPerSecond * (seconds + 1)
+ }
+ } else {
+ // rampUpPart is just a special case of the above formula: rampUpTimeSeconds == seconds
+ val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds)
+ rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond
+ }
+ }
+}
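A spot-check of valueAtSecond against the table in the comment above (rowsPerSecond = 10, rampUpTimeSeconds = 4, hence speedDeltaPerSecond = 2); not part of the patch:

val expected = Seq(0L -> 0L, 1L -> 2L, 2L -> 6L, 3L -> 12L, 4L -> 20L, 5L -> 30L, 6L -> 40L)
expected.foreach { case (seconds, endValue) =>
  assert(RateStreamSource.valueAtSecond(seconds, rowsPerSecond = 10, rampUpTimeSeconds = 4) == endValue)
}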
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index bcf0d970f7ec1..33f81d98ca593 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -23,6 +23,7 @@ import java.util.concurrent.{CountDownLatch, TimeUnit}
import java.util.concurrent.atomic.AtomicReference
import java.util.concurrent.locks.ReentrantLock
+import scala.collection.mutable.{Map => MutableMap}
import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal
@@ -148,15 +149,18 @@ class StreamExecution(
"logicalPlan must be initialized in StreamExecutionThread " +
s"but the current thread was ${Thread.currentThread}")
var nextSourceId = 0L
+ val toExecutionRelationMap = MutableMap[StreamingRelation, StreamingExecutionRelation]()
val _logicalPlan = analyzedPlan.transform {
- case StreamingRelation(dataSource, _, output) =>
- // Materialize source to avoid creating it in every batch
- val metadataPath = s"$checkpointRoot/sources/$nextSourceId"
- val source = dataSource.createSource(metadataPath)
- nextSourceId += 1
- // We still need to use the previous `output` instead of `source.schema` as attributes in
- // "df.logicalPlan" has already used attributes of the previous `output`.
- StreamingExecutionRelation(source, output)
+ case streamingRelation@StreamingRelation(dataSource, _, output) =>
+ toExecutionRelationMap.getOrElseUpdate(streamingRelation, {
+ // Materialize source to avoid creating it in every batch
+ val metadataPath = s"$checkpointRoot/sources/$nextSourceId"
+ val source = dataSource.createSource(metadataPath)
+ nextSourceId += 1
+ // We still need to use the previous `output` instead of `source.schema` as attributes in
+ // "df.logicalPlan" has already used attributes of the previous `output`.
+ StreamingExecutionRelation(source, output)
+ })
}
sources = _logicalPlan.collect { case s: StreamingExecutionRelation => s.source }
uniqueSources = sources.distinct
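With the transform now keyed by the StreamingRelation, a query that references the same streaming DataFrame more than once materializes a single Source instead of one per occurrence. An illustrative sketch (assumes a SparkSession named spark):

val events = spark.readStream.format("rate").load()
val doubled = events.union(events) // both legs resolve to the same StreamingExecutionRelation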
@@ -252,6 +256,8 @@ class StreamExecution(
*/
private def runBatches(): Unit = {
try {
+ sparkSession.sparkContext.setJobGroup(runId.toString, getBatchDescriptionString,
+ interruptOnCancel = true)
if (sparkSession.sessionState.conf.streamingMetricsEnabled) {
sparkSession.sparkContext.env.metricsSystem.registerSource(streamMetrics)
}
@@ -289,6 +295,7 @@ class StreamExecution(
if (currentBatchId < 0) {
// We'll do this initialization only once
populateStartOffsets(sparkSessionToRunBatches)
+ sparkSession.sparkContext.setJobDescription(getBatchDescriptionString)
logDebug(s"Stream running from $committedOffsets to $availableOffsets")
} else {
constructNextBatch()
@@ -308,6 +315,7 @@ class StreamExecution(
logDebug(s"batch ${currentBatchId} committed")
// We'll increase currentBatchId after we complete processing current batch's data
currentBatchId += 1
+ sparkSession.sparkContext.setJobDescription(getBatchDescriptionString)
} else {
currentStatus = currentStatus.copy(isDataAvailable = false)
updateStatusMessage("Waiting for data to arrive")
@@ -421,7 +429,10 @@ class StreamExecution(
availableOffsets = nextOffsets.toStreamProgress(sources)
/* Initialize committed offsets to a committed batch, which at this
* is the second latest batch id in the offset log. */
- offsetLog.get(latestBatchId - 1).foreach { secondLatestBatchId =>
+ if (latestBatchId != 0) {
+ val secondLatestBatchId = offsetLog.get(latestBatchId - 1).getOrElse {
+ throw new IllegalStateException(s"batch ${latestBatchId - 1} doesn't exist")
+ }
committedOffsets = secondLatestBatchId.toStreamProgress(sources)
}
@@ -560,10 +571,14 @@ class StreamExecution(
// Now that we've updated the scheduler's persistent checkpoint, it is safe for the
// sources to discard data from the previous batch.
- val prevBatchOff = offsetLog.get(currentBatchId - 1)
- if (prevBatchOff.isDefined) {
- prevBatchOff.get.toStreamProgress(sources).foreach {
- case (src, off) => src.commit(off)
+ if (currentBatchId != 0) {
+ val prevBatchOff = offsetLog.get(currentBatchId - 1)
+ if (prevBatchOff.isDefined) {
+ prevBatchOff.get.toStreamProgress(sources).foreach {
+ case (src, off) => src.commit(off)
+ }
+ } else {
+ throw new IllegalStateException(s"batch $currentBatchId doesn't exist")
}
}
@@ -623,7 +638,8 @@ class StreamExecution(
// Rewire the plan to use the new attributes that were returned by the source.
val replacementMap = AttributeMap(replacements)
val triggerLogicalPlan = withNewSources transformAllExpressions {
- case a: Attribute if replacementMap.contains(a) => replacementMap(a)
+ case a: Attribute if replacementMap.contains(a) =>
+ replacementMap(a).withMetadata(a.metadata)
case ct: CurrentTimestamp =>
CurrentBatchTimestamp(offsetSeqMetadata.batchTimestampMs,
ct.dataType)
@@ -684,8 +700,11 @@ class StreamExecution(
// intentionally
state.set(TERMINATED)
if (microBatchThread.isAlive) {
+ sparkSession.sparkContext.cancelJobGroup(runId.toString)
microBatchThread.interrupt()
microBatchThread.join()
+ // microBatchThread may spawn new jobs, so we need to cancel again to prevent a leak
+ sparkSession.sparkContext.cancelJobGroup(runId.toString)
}
logInfo(s"Query $prettyIdString was stopped")
}
@@ -758,7 +777,7 @@ class StreamExecution(
if (streamDeathCause != null) {
throw streamDeathCause
}
- if (noNewData) {
+ if (noNewData || !isActive) {
return
}
}
@@ -825,6 +844,11 @@ class StreamExecution(
}
}
+ private def getBatchDescriptionString: String = {
+ val batchDescription = if (currentBatchId < 0) "init" else currentBatchId.toString
+ Option(name).map(_ + "<br/>").getOrElse("") +
+ s"id = $id<br/>runId = $runId<br/>batch = $batchDescription"
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index 1426728f9b550..ef48fffe1d980 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution.streaming.state
import java.io.{DataInputStream, DataOutputStream, FileNotFoundException, IOException}
+import java.nio.channels.ClosedChannelException
import java.util.Locale
import scala.collection.JavaConverters._
@@ -202,13 +203,22 @@ private[state] class HDFSBackedStateStoreProvider(
/** Abort all the updates made on this store. This store will not be usable any more. */
override def abort(): Unit = {
verify(state == UPDATING || state == ABORTED, "Cannot abort after already committed")
+ try {
+ state = ABORTED
+ if (tempDeltaFileStream != null) {
+ tempDeltaFileStream.close()
+ }
+ if (tempDeltaFile != null) {
+ fs.delete(tempDeltaFile, true)
+ }
+ } catch {
+ case c: ClosedChannelException =>
+ // This can happen when underlying file output stream has been closed before the
+ // compression stream.
+ logDebug(s"Error aborting version $newVersion into $this", c)
- state = ABORTED
- if (tempDeltaFileStream != null) {
- tempDeltaFileStream.close()
- }
- if (tempDeltaFile != null) {
- fs.delete(tempDeltaFile, true)
+ case e: Exception =>
+ logWarning(s"Error aborting version $newVersion into $this", e)
}
logInfo(s"Aborted version $newVersion for $this")
}
@@ -438,9 +448,11 @@ private[state] class HDFSBackedStateStoreProvider(
private def writeSnapshotFile(version: Long, map: MapType): Unit = {
val fileToWrite = snapshotFile(version)
+ val tempFile =
+ new Path(fileToWrite.getParent, s"${fileToWrite.getName}.temp-${Random.nextLong}")
var output: DataOutputStream = null
Utils.tryWithSafeFinally {
- output = compressStream(fs.create(fileToWrite, false))
+ output = compressStream(fs.create(tempFile, false))
val iter = map.entrySet().iterator()
while(iter.hasNext) {
val entry = iter.next()
@@ -455,6 +467,12 @@ private[state] class HDFSBackedStateStoreProvider(
} {
if (output != null) output.close()
}
+ if (fs.exists(fileToWrite)) {
+ // Skip rename if the file is already created.
+ fs.delete(tempFile, true)
+ } else if (!fs.rename(tempFile, fileToWrite)) {
+ throw new IOException(s"Failed to rename $tempFile to $fileToWrite")
+ }
logInfo(s"Written snapshot file for version $version of $this at $fileToWrite")
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index 8dbda298c87bc..d4de046787b9a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -102,8 +102,13 @@ trait WatermarkSupport extends UnaryExecNode {
}
/** Predicate based on keys that matches data older than the watermark */
- lazy val watermarkPredicateForKeys: Option[Predicate] =
- watermarkExpression.map(newPredicate(_, keyExpressions))
+ lazy val watermarkPredicateForKeys: Option[Predicate] = watermarkExpression.flatMap { e =>
+ if (keyExpressions.exists(_.metadata.contains(EventTimeWatermark.delayKey))) {
+ Some(newPredicate(e, keyExpressions))
+ } else {
+ None
+ }
+ }
/** Predicate based on the child output that matches data older than the watermark. */
lazy val watermarkPredicateForData: Option[Predicate] =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala
index 23fc0bd0bce13..460fc946c3e6f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala
@@ -29,7 +29,8 @@ class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging
private val listener = parent.listener
override def render(request: HttpServletRequest): Seq[Node] = listener.synchronized {
- val parameterExecutionId = request.getParameter("id")
+ // stripXSS is called first to remove suspicious characters used in XSS attacks
+ val parameterExecutionId = UIUtils.stripXSS(request.getParameter("id"))
require(parameterExecutionId != null && parameterExecutionId.nonEmpty,
"Missing execution id parameter")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
index b4a91230a0012..e0c8cb3487c92 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
@@ -101,6 +101,9 @@ class SQLListener(conf: SparkConf) extends SparkListener with Logging {
private val retainedExecutions = conf.getInt("spark.sql.ui.retainedExecutions", 1000)
+ private val retainedStages = conf.getInt("spark.ui.retainedStages",
+ SparkUI.DEFAULT_RETAINED_STAGES)
+
private val activeExecutions = mutable.HashMap[Long, SQLExecutionUIData]()
// Old data in the following fields must be removed in "trimExecutionsIfNecessary".
@@ -113,7 +116,7 @@ class SQLListener(conf: SparkConf) extends SparkListener with Logging {
*/
private val _jobIdToExecutionId = mutable.HashMap[Long, Long]()
- private val _stageIdToStageMetrics = mutable.HashMap[Long, SQLStageMetrics]()
+ private val _stageIdToStageMetrics = mutable.LinkedHashMap[Long, SQLStageMetrics]()
private val failedExecutions = mutable.ListBuffer[SQLExecutionUIData]()
@@ -207,6 +210,14 @@ class SQLListener(conf: SparkConf) extends SparkListener with Logging {
}
}
+ override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = synchronized {
+ val extraStages = _stageIdToStageMetrics.size - retainedStages
+ if (extraStages > 0) {
+ val toRemove = _stageIdToStageMetrics.take(extraStages).keys
+ _stageIdToStageMetrics --= toRemove
+ }
+ }
+
override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
if (taskEnd.taskMetrics != null) {
updateTaskAccumulatorValues(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala
index c9f5d3b3d92d7..dfa1100c37a0a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala
@@ -145,10 +145,13 @@ private[window] final class AggregateProcessor(
/** Update the buffer. */
def update(input: InternalRow): Unit = {
- updateProjection(join(buffer, input))
+ // TODO(hvanhovell) this sacrifices performance for correctness. We should make sure that
+ // MutableProjection makes copies of the complex input objects it buffers.
+ val copy = input.copy()
+ updateProjection(join(buffer, copy))
var i = 0
while (i < numImperatives) {
- imperatives(i).update(buffer, input)
+ imperatives(i).update(buffer, copy)
i += 1
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
index 950a6794a74a3..b9c932ae21727 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
@@ -282,6 +282,7 @@ case class WindowExec(
// Unwrap the expressions and factories from the map.
val expressions = windowFrameExpressionFactoryPairs.flatMap(_._1)
val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray
+ val inMemoryThreshold = sqlContext.conf.windowExecBufferInMemoryThreshold
val spillThreshold = sqlContext.conf.windowExecBufferSpillThreshold
// Start processing.
@@ -312,7 +313,8 @@ case class WindowExec(
val inputFields = child.output.length
val buffer: ExternalAppendOnlyUnsafeRowArray =
- new ExternalAppendOnlyUnsafeRowArray(spillThreshold)
+ new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold)
+
var bufferIterator: Iterator[UnsafeRow] = _
val windowFunctionResult = new SpecificInternalRow(expressions.map(_.dataType))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
index af2b4fb92062b..156002ef58fbe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
@@ -195,15 +195,6 @@ private[window] final class SlidingWindowFunctionFrame(
override def write(index: Int, current: InternalRow): Unit = {
var bufferUpdated = index == 0
- // Add all rows to the buffer for which the input row value is equal to or less than
- // the output row upper bound.
- while (nextRow != null && ubound.compare(nextRow, inputHighIndex, current, index) <= 0) {
- buffer.add(nextRow.copy())
- nextRow = WindowFunctionFrame.getNextOrNull(inputIterator)
- inputHighIndex += 1
- bufferUpdated = true
- }
-
// Drop all rows from the buffer for which the input row value is smaller than
// the output row lower bound.
while (!buffer.isEmpty && lbound.compare(buffer.peek(), inputLowIndex, current, index) < 0) {
@@ -212,6 +203,19 @@ private[window] final class SlidingWindowFunctionFrame(
bufferUpdated = true
}
+ // Add all rows to the buffer for which the input row value is equal to or less than
+ // the output row upper bound.
+ while (nextRow != null && ubound.compare(nextRow, inputHighIndex, current, index) <= 0) {
+ if (lbound.compare(nextRow, inputLowIndex, current, index) < 0) {
+ inputLowIndex += 1
+ } else {
+ buffer.add(nextRow.copy())
+ bufferUpdated = true
+ }
+ nextRow = WindowFunctionFrame.getNextOrNull(inputIterator)
+ inputHighIndex += 1
+ }
+
// Only recalculate and update when the buffer changes.
if (bufferUpdated) {
processor.initialize(input.length)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index f07e04368389f..067b6d528d3f2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -23,13 +23,13 @@ import scala.reflect.runtime.universe.{typeTag, TypeTag}
import scala.util.Try
import scala.util.control.NonFatal
-import org.apache.spark.annotation.{Experimental, InterfaceStability}
+import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint
+import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, ResolvedHint}
import org.apache.spark.sql.execution.SparkSqlParser
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.internal.SQLConf
@@ -1019,7 +1019,8 @@ object functions {
* @since 1.5.0
*/
def broadcast[T](df: Dataset[T]): Dataset[T] = {
- Dataset[T](df.sparkSession, BroadcastHint(df.logicalPlan))(df.exprEnc)
+ Dataset[T](df.sparkSession,
+ ResolvedHint(df.logicalPlan, HintInfo(isBroadcastable = Option(true))))(df.exprEnc)
}
/**
@@ -2291,7 +2292,8 @@ object functions {
}
/**
- * Left-pad the string column with
+ * Left-pad the string column with pad to a length of len. If the string column is longer
+ * than len, the return value is shortened to len characters.
*
* @group string_funcs
* @since 1.5.0
@@ -2349,7 +2351,8 @@ object functions {
def unbase64(e: Column): Column = withExpr { UnBase64(e.expr) }
/**
- * Right-padded with pad to a length of len.
+ * Right-pad the string column with pad to a length of len. If the string column is longer
+ * than len, the return value is shortened to len characters.
*
* @group string_funcs
* @since 1.5.0
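The doc fix above is about the truncation semantics of lpad/rpad; a small illustration (assumes a SparkSession named spark):

import org.apache.spark.sql.functions.{lit, lpad, rpad}

spark.range(1).select(
  lpad(lit("abc"), 5, "#"),    // "##abc"
  lpad(lit("abcdef"), 5, "#"), // "abcde": longer than len, so shortened to len
  rpad(lit("abc"), 5, "#")     // "abc##"
).show()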
@@ -2793,8 +2796,6 @@ object functions {
* @group datetime_funcs
* @since 2.0.0
*/
- @Experimental
- @InterfaceStability.Evolving
def window(
timeColumn: Column,
windowDuration: String,
@@ -2847,8 +2848,6 @@ object functions {
* @group datetime_funcs
* @since 2.0.0
*/
- @Experimental
- @InterfaceStability.Evolving
def window(timeColumn: Column, windowDuration: String, slideDuration: String): Column = {
window(timeColumn, windowDuration, slideDuration, "0 second")
}
@@ -2886,8 +2885,6 @@ object functions {
* @group datetime_funcs
* @since 2.0.0
*/
- @Experimental
- @InterfaceStability.Evolving
def window(timeColumn: Column, windowDuration: String): Column = {
window(timeColumn, windowDuration, windowDuration, "0 second")
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 2b14eca919fa4..2a801d87b12eb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.internal
import org.apache.spark.SparkConf
import org.apache.spark.annotation.{Experimental, InterfaceStability}
-import org.apache.spark.sql.{ExperimentalMethods, SparkSession, Strategy, UDFRegistration}
+import org.apache.spark.sql.{ExperimentalMethods, SparkSession, UDFRegistration, _}
import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.optimizer.Optimizer
@@ -63,6 +63,11 @@ abstract class BaseSessionStateBuilder(
*/
protected def newBuilder: NewBuilder
+ /**
+ * Session extensions defined in the [[SparkSession]].
+ */
+ protected def extensions: SparkSessionExtensions = session.extensions
+
/**
* Extract entries from `SparkConf` and put them in the `SQLConf`
*/
@@ -108,7 +113,9 @@ abstract class BaseSessionStateBuilder(
*
* Note: this depends on the `conf` field.
*/
- protected lazy val sqlParser: ParserInterface = new SparkSqlParser(conf)
+ protected lazy val sqlParser: ParserInterface = {
+ extensions.buildParser(session, new SparkSqlParser(conf))
+ }
/**
* ResourceLoader that is used to load function resources and jars.
@@ -171,7 +178,9 @@ abstract class BaseSessionStateBuilder(
*
* Note that this may NOT depend on the `analyzer` function.
*/
- protected def customResolutionRules: Seq[Rule[LogicalPlan]] = Nil
+ protected def customResolutionRules: Seq[Rule[LogicalPlan]] = {
+ extensions.buildResolutionRules(session)
+ }
/**
* Custom post resolution rules to add to the Analyzer. Prefer overriding this instead of
@@ -179,7 +188,9 @@ abstract class BaseSessionStateBuilder(
*
* Note that this may NOT depend on the `analyzer` function.
*/
- protected def customPostHocResolutionRules: Seq[Rule[LogicalPlan]] = Nil
+ protected def customPostHocResolutionRules: Seq[Rule[LogicalPlan]] = {
+ extensions.buildPostHocResolutionRules(session)
+ }
/**
* Custom check rules to add to the Analyzer. Prefer overriding this instead of creating
@@ -187,7 +198,9 @@ abstract class BaseSessionStateBuilder(
*
* Note that this may NOT depend on the `analyzer` function.
*/
- protected def customCheckRules: Seq[LogicalPlan => Unit] = Nil
+ protected def customCheckRules: Seq[LogicalPlan => Unit] = {
+ extensions.buildCheckRules(session)
+ }
/**
* Logical query plan optimizer.
@@ -207,7 +220,9 @@ abstract class BaseSessionStateBuilder(
*
* Note that this may NOT depend on the `optimizer` function.
*/
- protected def customOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = Nil
+ protected def customOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = {
+ extensions.buildOptimizerRules(session)
+ }
/**
* Planner that converts optimized logical plans to physical plans.
@@ -227,7 +242,9 @@ abstract class BaseSessionStateBuilder(
*
* Note that this may NOT depend on the `planner` function.
*/
- protected def customPlanningStrategies: Seq[Strategy] = Nil
+ protected def customPlanningStrategies: Seq[Strategy] = {
+ extensions.buildPlannerStrategies(session)
+ }
/**
* Create a query execution object.
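
The builder now sources its parser and all custom analyzer/optimizer/planner hooks from `session.extensions` instead of returning hard-coded defaults. A hedged sketch of how an application feeds rules into those hooks through `SparkSession.Builder.withExtensions` (part of the same extensions feature); `NoopRule` is a made-up placeholder, not part of this patch:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

// A rule that does nothing, purely to show the wiring.
case class NoopRule(spark: SparkSession) extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan
}

val spark = SparkSession.builder()
  .master("local[*]")
  .withExtensions { extensions =>
    // Ends up in customOperatorOptimizationRules via extensions.buildOptimizerRules(session)
    extensions.injectOptimizerRule(NoopRule)
  }
  .getOrCreate()

The other hooks follow the same shape: injectParser feeds sqlParser, injectResolutionRule and injectPostHocResolutionRule feed the analyzer, injectCheckRule feeds the check rules, and injectPlannerStrategy feeds the planner.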
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
index aebb663df5c92..f5338acdd68d3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.internal
import scala.reflect.runtime.universe.TypeTag
+import scala.util.control.NonFatal
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql._
@@ -98,14 +99,27 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
CatalogImpl.makeDataset(tables, sparkSession)
}
+ /**
+ * Returns a Table for the given table/view or temporary view.
+ *
+ * Note that this function requires the table to already exist in the Catalog.
+ *
+ * If table metadata retrieval fails for any reason (e.g., the table's serde class
+ * is not accessible or the table type is not accepted by Spark SQL), this function
+ * still returns the corresponding Table, just without the description and tableType.
+ */
private def makeTable(tableIdent: TableIdentifier): Table = {
- val metadata = sessionCatalog.getTempViewOrPermanentTableMetadata(tableIdent)
+ val metadata = try {
+ Some(sessionCatalog.getTempViewOrPermanentTableMetadata(tableIdent))
+ } catch {
+ case NonFatal(_) => None
+ }
val isTemp = sessionCatalog.isTemporaryTable(tableIdent)
new Table(
name = tableIdent.table,
- database = metadata.identifier.database.orNull,
- description = metadata.comment.orNull,
- tableType = if (isTemp) "TEMPORARY" else metadata.tableType.name,
+ database = metadata.map(_.identifier.database).getOrElse(tableIdent.database).orNull,
+ description = metadata.map(_.comment.orNull).orNull,
+ tableType = if (isTemp) "TEMPORARY" else metadata.map(_.tableType.name).orNull,
isTemporary = isTemp)
}
@@ -197,7 +211,11 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
* `AnalysisException` when no `Table` can be found.
*/
override def getTable(dbName: String, tableName: String): Table = {
- makeTable(TableIdentifier(tableName, Option(dbName)))
+ if (tableExists(dbName, tableName)) {
+ makeTable(TableIdentifier(tableName, Option(dbName)))
+ } else {
+ throw new AnalysisException(s"Table or view '$tableName' not found in database '$dbName'")
+ }
}
/**
@@ -444,13 +462,20 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
*/
override def refreshTable(tableName: String): Unit = {
val tableIdent = sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName)
- // Temp tables: refresh (or invalidate) any metadata/data cached in the plan recursively.
- // Non-temp tables: refresh the metadata cache.
- sessionCatalog.refreshTable(tableIdent)
+ val tableMetadata = sessionCatalog.getTempViewOrPermanentTableMetadata(tableIdent)
+ val table = sparkSession.table(tableIdent)
+
+ if (tableMetadata.tableType == CatalogTableType.VIEW) {
+ // Temp or persistent views: refresh (or invalidate) any metadata/data cached
+ // in the plan recursively.
+ table.queryExecution.analyzed.refresh()
+ } else {
+ // Non-temp tables: refresh the metadata cache.
+ sessionCatalog.refreshTable(tableIdent)
+ }
// If this table is cached as an InMemoryRelation, drop the original
// cached version and make the new version cached lazily.
- val table = sparkSession.table(tableIdent)
if (isCached(table)) {
// Uncache the logicalPlan.
sparkSession.sharedState.cacheManager.uncacheQuery(table, blocking = true)
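
Taken together, the `CatalogImpl` changes make `getTable` fail fast when the table is missing, let `makeTable` tolerate unreadable metadata, and make `refreshTable` treat views and plain tables differently. A rough usage sketch; the local session and the table name `t1` are assumptions for illustration:

import org.apache.spark.sql.{AnalysisException, SparkSession}

val spark = SparkSession.builder().master("local[*]").getOrCreate()
spark.range(3).write.saveAsTable("t1")   // assumes no table named t1 exists yet

// getTable now checks existence first and raises AnalysisException otherwise
val t = spark.catalog.getTable("default", "t1")
println(s"${t.name} db=${t.database} type=${t.tableType}")

try {
  spark.catalog.getTable("default", "does_not_exist")
} catch {
  case e: AnalysisException => println(s"expected: ${e.getMessage}")
}

// Views refresh their analyzed plan recursively; plain tables only refresh the metadata cache
spark.catalog.refreshTable("t1")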
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala
index b9515ec7bca2a..eca612f06f9bb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala
@@ -31,7 +31,8 @@ object HiveSerDe {
"sequencefile" ->
HiveSerDe(
inputFormat = Option("org.apache.hadoop.mapred.SequenceFileInputFormat"),
- outputFormat = Option("org.apache.hadoop.mapred.SequenceFileOutputFormat")),
+ outputFormat = Option("org.apache.hadoop.mapred.SequenceFileOutputFormat"),
+ serde = Option("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")),
"rcfile" ->
HiveSerDe(
@@ -54,7 +55,8 @@ object HiveSerDe {
"textfile" ->
HiveSerDe(
inputFormat = Option("org.apache.hadoop.mapred.TextInputFormat"),
- outputFormat = Option("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")),
+ outputFormat = Option("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat"),
+ serde = Option("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")),
"avro" ->
HiveSerDe(
@@ -73,6 +75,7 @@ object HiveSerDe {
val key = source.toLowerCase(Locale.ROOT) match {
case s if s.startsWith("org.apache.spark.sql.parquet") => "parquet"
case s if s.startsWith("org.apache.spark.sql.orc") => "orc"
+ case s if s.startsWith("org.apache.spark.sql.hive.orc") => "orc"
case s if s.equals("orcfile") => "orc"
case s if s.equals("parquetfile") => "parquet"
case s if s.equals("avrofile") => "avro"
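
The serde additions are visible through the `HiveSerDe.sourceToSerDe` lookup used elsewhere in the planner; a tiny sketch, assuming that helper keeps its current name and signature:

import org.apache.spark.sql.internal.HiveSerDe

// textfile and sequencefile entries now carry an explicit LazySimpleSerDe
println(HiveSerDe.sourceToSerDe("textfile").flatMap(_.serde))
// expected: Some(org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe)

// Hive-flavoured ORC class names normalize to the "orc" entry as well
println(HiveSerDe.sourceToSerDe("org.apache.spark.sql.hive.orc.OrcFileFormat").isDefined)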
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
index 0289471bf841a..7202f1222d10f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
@@ -17,10 +17,14 @@
package org.apache.spark.sql.internal
+import java.net.URL
+import java.util.Locale
+
import scala.reflect.ClassTag
import scala.util.control.NonFatal
import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.FsUrlStreamHandlerFactory
import org.apache.spark.{SparkConf, SparkContext, SparkException}
import org.apache.spark.internal.Logging
@@ -86,35 +90,42 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging {
/**
* A catalog that interacts with external systems.
*/
- lazy val externalCatalog: ExternalCatalog =
- SharedState.reflect[ExternalCatalog, SparkConf, Configuration](
+ lazy val externalCatalog: ExternalCatalog = {
+ val externalCatalog = SharedState.reflect[ExternalCatalog, SparkConf, Configuration](
SharedState.externalCatalogClassName(sparkContext.conf),
sparkContext.conf,
sparkContext.hadoopConfiguration)
- // Create the default database if it doesn't exist.
- {
val defaultDbDefinition = CatalogDatabase(
SessionCatalog.DEFAULT_DATABASE,
"default database",
CatalogUtils.stringToURI(warehousePath),
Map())
- // Initialize default database if it doesn't exist
+ // Create default database if it doesn't exist
if (!externalCatalog.databaseExists(SessionCatalog.DEFAULT_DATABASE)) {
// There may be another Spark application creating the default database at the same time;
// here we set `ignoreIfExists = true` to avoid a `DatabaseAlreadyExists` exception.
externalCatalog.createDatabase(defaultDbDefinition, ignoreIfExists = true)
}
+
+ // Make sure we propagate external catalog events to the spark listener bus
+ externalCatalog.addListener(new ExternalCatalogEventListener {
+ override def onEvent(event: ExternalCatalogEvent): Unit = {
+ sparkContext.listenerBus.post(event)
+ }
+ })
+
+ externalCatalog
}
/**
* A manager for global temporary views.
*/
- val globalTempViewManager: GlobalTempViewManager = {
+ lazy val globalTempViewManager: GlobalTempViewManager = {
// System preserved database should not exist in metastore. However it's hard to guarantee it
// for every session, because case-sensitivity differs. Here we always lowercase it to make our
// life easier.
- val globalTempDB = sparkContext.conf.get(GLOBAL_TEMP_DATABASE).toLowerCase
+ val globalTempDB = sparkContext.conf.get(GLOBAL_TEMP_DATABASE).toLowerCase(Locale.ROOT)
if (externalCatalog.databaseExists(globalTempDB)) {
throw new SparkException(
s"$globalTempDB is a system preserved database, please rename your existing database " +
@@ -145,7 +156,13 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging {
}
}
-object SharedState {
+object SharedState extends Logging {
+ try {
+ URL.setURLStreamHandlerFactory(new FsUrlStreamHandlerFactory())
+ } catch {
+ case e: Error =>
+ logWarning("URL.setURLStreamHandlerFactory failed to set FsUrlStreamHandlerFactory")
+ }
private val HIVE_EXTERNAL_CATALOG_CLASS_NAME = "org.apache.spark.sql.hive.HiveExternalCatalog"
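
Registering `FsUrlStreamHandlerFactory` in the `SharedState` companion object teaches `java.net.URL` the Hadoop filesystem schemes as soon as the class loads; because the JVM allows the factory to be set only once, the call is guarded. A small illustration of the same pattern in application code (the HDFS path is hypothetical):

import java.net.URL
import org.apache.hadoop.fs.FsUrlStreamHandlerFactory

// A second registration throws java.lang.Error, which is why SharedState wraps the call;
// code that might race with it should guard the same way.
try {
  URL.setURLStreamHandlerFactory(new FsUrlStreamHandlerFactory())
} catch {
  case _: Error => // already registered, e.g. by SharedState itself
}

// After registration, hdfs:// URLs resolve through java.net.URL directly, for example:
// val in = new URL("hdfs://namenode:8020/some/file").openStream()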
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
index f541996b651e9..20e634c06b610 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
@@ -43,10 +43,6 @@ private case object OracleDialect extends JdbcDialect {
// Not sure if there is a more robust way to identify the field as a float (or other
// numeric types that do not specify a scale).
case _ if scale == -127L => Option(DecimalType(DecimalType.MAX_PRECISION, 10))
- case 1 => Option(BooleanType)
- case 3 | 5 | 10 => Option(IntegerType)
- case 19 if scale == 0L => Option(LongType)
- case 19 if scale == 4L => Option(FloatType)
case _ => None
}
} else {
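
Dropping the precision-based shortcuts means Oracle NUMBER(1), NUMBER(3/5/10) and NUMBER(19) columns now fall back to the generic decimal mapping. If an application relied on the old behavior, a custom dialect can reinstate the pieces it needs; the sketch below is an illustrative assumption, not part of this patch:

import java.sql.Types
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
import org.apache.spark.sql.types.{BooleanType, DataType, MetadataBuilder}

// Example: map NUMBER(1) back to BooleanType, leave everything else to the built-in dialects.
object LegacyOracleBooleanDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:oracle")
  override def getCatalystType(
      sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
    if (sqlType == Types.NUMERIC && size == 1) Some(BooleanType) else None
  }
}

JdbcDialects.registerDialect(LegacyOracleBooleanDialect)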
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index ff8b15b3ff3ff..86eeb2f7dd419 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -163,16 +163,13 @@ trait StreamSinkProvider {
@InterfaceStability.Stable
trait CreatableRelationProvider {
/**
- * Save the DataFrame to the destination and return a relation with the given parameters based on
- * the contents of the given DataFrame. The mode specifies the expected behavior of createRelation
- * when data already exists.
- * Right now, there are three modes, Append, Overwrite, and ErrorIfExists.
- * Append mode means that when saving a DataFrame to a data source, if data already exists,
- * contents of the DataFrame are expected to be appended to existing data.
- * Overwrite mode means that when saving a DataFrame to a data source, if data already exists,
- * existing data is expected to be overwritten by the contents of the DataFrame.
- * ErrorIfExists mode means that when saving a DataFrame to a data source,
- * if data already exists, an exception is expected to be thrown.
+ * Saves a DataFrame to a destination (using data source-specific parameters)
+ *
+ * @param sqlContext SQLContext
+ * @param mode specifies what happens when the destination already exists
+ * @param parameters data source-specific parameters
+ * @param data DataFrame to save (i.e. the rows after executing the query)
+ * @return Relation with a known schema
*
* @since 1.3.0
*/
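
The reworded scaladoc states the same contract with explicit `@param` tags. A hedged sketch of a trivial implementation of that contract; the `NoopSource` name and its behavior are invented for illustration:

import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider}
import org.apache.spark.sql.types.StructType

class NoopSource extends CreatableRelationProvider {
  override def createRelation(
      sqlContext: SQLContext,
      mode: SaveMode,                  // what to do when the destination already exists
      parameters: Map[String, String], // data source-specific options
      data: DataFrame): BaseRelation = {
    data.foreach(_ => ())              // "save" by simply consuming the rows
    val ctx = sqlContext               // capture for the anonymous relation below
    new BaseRelation {
      override def sqlContext: SQLContext = ctx
      override def schema: StructType = data.schema
    }
  }
}

A writer would reach it with something like df.write.format(classOf[NoopSource].getName).mode(SaveMode.Append).save().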
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
index 746b2a94f102d..7e8e6394b4862 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
@@ -21,7 +21,7 @@ import java.util.Locale
import scala.collection.JavaConverters._
-import org.apache.spark.annotation.{Experimental, InterfaceStability}
+import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.execution.command.DDLUtils
@@ -35,7 +35,6 @@ import org.apache.spark.sql.types.StructType
*
* @since 2.0.0
*/
-@Experimental
@InterfaceStability.Evolving
final class DataStreamReader private[sql](sparkSession: SparkSession) extends Logging {
/**
@@ -164,7 +163,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
* Loads a JSON file stream and returns the results as a `DataFrame`.
*
* JSON Lines (newline-delimited JSON) is supported by
- * default. For JSON (one record per file), set the `wholeFile` option to true.
+ * default. For JSON (one record per file), set the `multiLine` option to true.
*
* This function goes through the input once to determine the input schema. If you know the
* schema in advance, use the version that specifies the schema to avoid the extra scan.
@@ -206,7 +205,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
* `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that
* indicates a timestamp format. Custom date formats follow the formats at
* `java.text.SimpleDateFormat`. This applies to timestamp type.
- * `wholeFile` (default `false`): parse one record, which may span multiple lines,
+ * `multiLine` (default `false`): parse one record, which may span multiple lines,
* per file
*
*
@@ -277,7 +276,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
* `columnNameOfCorruptRecord` (default is the value specified in
* `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string
* created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.
- * `wholeFile` (default `false`): parse one record, which may span multiple lines.
+ * `multiLine` (default `false`): parse one record, which may span multiple lines.
*
*
* @since 2.0.0
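
Besides dropping `@Experimental`, the reader scaladoc tracks the rename of the JSON/CSV option from `wholeFile` to `multiLine`. A small sketch of the streaming read path with the new option name; the schema and input directory are placeholders:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{LongType, StringType, StructType}

val spark = SparkSession.builder().master("local[*]").getOrCreate()

val schema = new StructType().add("id", LongType).add("payload", StringType)

val parsed = spark.readStream
  .schema(schema)               // streaming file sources need a user-supplied schema
  .option("multiLine", "true")  // one JSON record may span several lines per file
  .json("/tmp/json-input-dir")  // placeholder directory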
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index 0d2611f9bbcce..0be69b98abc8a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -21,7 +21,7 @@ import java.util.Locale
import scala.collection.JavaConverters._
-import org.apache.spark.annotation.{Experimental, InterfaceStability}
+import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.sql.{AnalysisException, Dataset, ForeachWriter}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.execution.command.DDLUtils
@@ -29,13 +29,11 @@ import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.streaming.{ForeachSink, MemoryPlan, MemorySink}
/**
- * :: Experimental ::
* Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems,
* key-value stores, etc). Use `Dataset.writeStream` to access this.
*
* @since 2.0.0
*/
-@Experimental
@InterfaceStability.Evolving
final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
@@ -269,12 +267,6 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
useTempCheckpointLocation = true,
trigger = trigger)
} else {
- val (useTempCheckpointLocation, recoverFromCheckpointLocation) =
- if (source == "console") {
- (true, false)
- } else {
- (false, true)
- }
val dataSource =
DataSource(
df.sparkSession,
@@ -287,8 +279,8 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
df,
dataSource.createSink(outputMode),
outputMode,
- useTempCheckpointLocation = useTempCheckpointLocation,
- recoverFromCheckpointLocation = recoverFromCheckpointLocation,
+ useTempCheckpointLocation = source == "console",
+ recoverFromCheckpointLocation = true,
trigger = trigger)
}
}
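
With the special-case branch removed, every sink on this code path recovers from an existing checkpoint, and only the console sink gets a temporary checkpoint location. A sketch of starting a console query; the socket source assumes something (e.g. `nc -lk 9999`) is listening locally:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()

val lines = spark.readStream
  .format("socket")
  .option("host", "localhost")
  .option("port", 9999)
  .load()

// No checkpointLocation is required here: console queries use a temporary one
val query = lines.writeStream
  .format("console")
  .outputMode("append")
  .start()

query.awaitTermination(10000)   // run for roughly ten seconds in this sketch
query.stop()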
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala
index c659ac7fcf3d9..04a956b70b022 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala
@@ -212,7 +212,7 @@ trait GroupState[S] extends LogicalGroupState[S] {
@throws[IllegalArgumentException]("when updating with null")
def update(newState: S): Unit
- /** Remove this state. Note that this resets any timeout configuration as well. */
+ /** Remove this state. */
def remove(): Unit
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala
index 9ba1fc01cbd30..a033575d3d38f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala
@@ -23,11 +23,10 @@ import scala.concurrent.duration.Duration
import org.apache.commons.lang3.StringUtils
-import org.apache.spark.annotation.{Experimental, InterfaceStability}
+import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.unsafe.types.CalendarInterval
/**
- * :: Experimental ::
* A trigger that runs a query periodically based on the processing time. If `interval` is 0,
* the query will run as fast as possible.
*
@@ -49,7 +48,6 @@ import org.apache.spark.unsafe.types.CalendarInterval
*
* @since 2.0.0
*/
-@Experimental
@InterfaceStability.Evolving
@deprecated("use Trigger.ProcessingTime(intervalMs)", "2.2.0")
case class ProcessingTime(intervalMs: Long) extends Trigger {
@@ -57,12 +55,10 @@ case class ProcessingTime(intervalMs: Long) extends Trigger {
}
/**
- * :: Experimental ::
* Used to create [[ProcessingTime]] triggers for [[StreamingQuery]]s.
*
* @since 2.0.0
*/
-@Experimental
@InterfaceStability.Evolving
@deprecated("use Trigger.ProcessingTime(intervalMs)", "2.2.0")
object ProcessingTime {
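
Both the case class and its companion are deprecated in favor of the `Trigger` factory methods; the replacement usage is a one-liner (sketch, attached to any streaming writer):

import org.apache.spark.sql.streaming.Trigger

// Preferred over the deprecated ProcessingTime(...) constructors:
val every10s = Trigger.ProcessingTime("10 seconds")

// Typically passed straight to a DataStreamWriter, e.g.
// df.writeStream.trigger(every10s).format("console").start()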
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala
index 12a1bb1db5779..f2dfbe42260d7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala
@@ -19,16 +19,14 @@ package org.apache.spark.sql.streaming
import java.util.UUID
-import org.apache.spark.annotation.{Experimental, InterfaceStability}
+import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.sql.SparkSession
/**
- * :: Experimental ::
* A handle to a query that is executing continuously in the background as new data arrives.
* All these methods are thread-safe.
* @since 2.0.0
*/
-@Experimental
@InterfaceStability.Evolving
trait StreamingQuery {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala
index 234a1166a1953..03aeb14de502a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala
@@ -17,10 +17,9 @@
package org.apache.spark.sql.streaming
-import org.apache.spark.annotation.{Experimental, InterfaceStability}
+import org.apache.spark.annotation.InterfaceStability
/**
- * :: Experimental ::
* Exception that stopped a [[StreamingQuery]]. Use `cause` to get the actual exception
* that caused the failure.
* @param message Message of this exception
@@ -29,7 +28,6 @@ import org.apache.spark.annotation.{Experimental, InterfaceStability}
* @param endOffset Ending offset in json of the range of data in which the exception occurred
* @since 2.0.0
*/
-@Experimental
@InterfaceStability.Evolving
class StreamingQueryException private[sql](
private val queryDebugString: String,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala
index c376913516ef7..6aa82b89ede81 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala
@@ -19,17 +19,15 @@ package org.apache.spark.sql.streaming
import java.util.UUID
-import org.apache.spark.annotation.{Experimental, InterfaceStability}
+import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.scheduler.SparkListenerEvent
/**
- * :: Experimental ::
* Interface for listening to events related to [[StreamingQuery StreamingQueries]].
* @note The methods are not thread-safe as they may be called from different threads.
*
* @since 2.0.0
*/
-@Experimental
@InterfaceStability.Evolving
abstract class StreamingQueryListener {
@@ -66,32 +64,26 @@ abstract class StreamingQueryListener {
/**
- * :: Experimental ::
* Companion object of [[StreamingQueryListener]] that defines the listener events.
* @since 2.0.0
*/
-@Experimental
@InterfaceStability.Evolving
object StreamingQueryListener {
/**
- * :: Experimental ::
* Base type of [[StreamingQueryListener]] events
* @since 2.0.0
*/
- @Experimental
@InterfaceStability.Evolving
trait Event extends SparkListenerEvent
/**
- * :: Experimental ::
* Event representing the start of a query
* @param id A unique query id that persists across restarts. See `StreamingQuery.id()`.
* @param runId A query id that is unique for every start/restart. See `StreamingQuery.runId()`.
* @param name User-specified name of the query, null if not specified.
* @since 2.1.0
*/
- @Experimental
@InterfaceStability.Evolving
class QueryStartedEvent private[sql](
val id: UUID,
@@ -99,17 +91,14 @@ object StreamingQueryListener {
val name: String) extends Event
/**
- * :: Experimental ::
* Event representing any progress updates in a query.
* @param progress The query progress updates.
* @since 2.1.0
*/
- @Experimental
@InterfaceStability.Evolving
class QueryProgressEvent private[sql](val progress: StreamingQueryProgress) extends Event
/**
- * :: Experimental ::
* Event representing the termination of a query.
*
* @param id A unique query id that persists across restarts. See `StreamingQuery.id()`.
@@ -118,7 +107,6 @@ object StreamingQueryListener {
* with an exception. Otherwise, it will be `None`.
* @since 2.1.0
*/
- @Experimental
@InterfaceStability.Evolving
class QueryTerminatedEvent private[sql](
val id: UUID,
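
The annotations are gone but the listener API itself is unchanged. A short sketch of wiring a listener up through `spark.streams`; the println bodies are placeholders:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.sql.streaming.StreamingQueryListener._

val spark = SparkSession.builder().master("local[*]").getOrCreate()

spark.streams.addListener(new StreamingQueryListener {
  override def onQueryStarted(event: QueryStartedEvent): Unit =
    println(s"started: id=${event.id} runId=${event.runId} name=${event.name}")
  override def onQueryProgress(event: QueryProgressEvent): Unit =
    println(event.progress.prettyJson)
  override def onQueryTerminated(event: QueryTerminatedEvent): Unit =
    println(s"terminated: id=${event.id} exception=${event.exception}")
})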
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
index 7810d9f6e9642..002c45413b4c2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
@@ -24,7 +24,7 @@ import scala.collection.mutable
import org.apache.hadoop.fs.Path
-import org.apache.spark.annotation.{Experimental, InterfaceStability}
+import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession}
import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker
@@ -34,12 +34,10 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.{Clock, SystemClock, Utils}
/**
- * :: Experimental ::
- * A class to manage all the [[StreamingQuery]] active on a `SparkSession`.
+ * A class to manage all the [[StreamingQuery]] active in a `SparkSession`.
*
* @since 2.0.0
*/
-@Experimental
@InterfaceStability.Evolving
class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Logging {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala
index 687b1267825fe..a0c9bcc8929eb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala
@@ -22,10 +22,9 @@ import org.json4s.JsonAST.JValue
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
-import org.apache.spark.annotation.{Experimental, InterfaceStability}
+import org.apache.spark.annotation.InterfaceStability
/**
- * :: Experimental ::
* Reports information about the instantaneous status of a streaming query.
*
* @param message A human readable description of what the stream is currently doing.
@@ -35,7 +34,6 @@ import org.apache.spark.annotation.{Experimental, InterfaceStability}
*
* @since 2.1.0
*/
-@Experimental
@InterfaceStability.Evolving
class StreamingQueryStatus protected[sql](
val message: String,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
index 35fe6b8605fad..5171852c48b9e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
@@ -29,13 +29,11 @@ import org.json4s.JsonAST.JValue
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
-import org.apache.spark.annotation.{Experimental, InterfaceStability}
+import org.apache.spark.annotation.InterfaceStability
/**
- * :: Experimental ::
* Information about updates made to stateful operators in a [[StreamingQuery]] during a trigger.
*/
-@Experimental
@InterfaceStability.Evolving
class StateOperatorProgress private[sql](
val numRowsTotal: Long,
@@ -51,10 +49,11 @@ class StateOperatorProgress private[sql](
("numRowsTotal" -> JInt(numRowsTotal)) ~
("numRowsUpdated" -> JInt(numRowsUpdated))
}
+
+ override def toString: String = prettyJson
}
/**
- * :: Experimental ::
* Information about progress made in the execution of a [[StreamingQuery]] during
* a trigger. Each event relates to processing done for a single trigger of the streaming
* query. Events are emitted even when no new data is available to be processed.
@@ -80,7 +79,6 @@ class StateOperatorProgress private[sql](
* @param sources detailed statistics on data being read from each of the streaming sources.
* @since 2.1.0
*/
-@Experimental
@InterfaceStability.Evolving
class StreamingQueryProgress private[sql](
val id: UUID,
@@ -139,7 +137,6 @@ class StreamingQueryProgress private[sql](
}
/**
- * :: Experimental ::
* Information about progress made for a source in the execution of a [[StreamingQuery]]
* during a trigger. See [[StreamingQueryProgress]] for more information.
*
@@ -152,7 +149,6 @@ class StreamingQueryProgress private[sql](
* Spark.
* @since 2.1.0
*/
-@Experimental
@InterfaceStability.Evolving
class SourceProgress protected[sql](
val description: String,
@@ -191,14 +187,12 @@ class SourceProgress protected[sql](
}
/**
- * :: Experimental ::
* Information about progress made for a sink in the execution of a [[StreamingQuery]]
* during a trigger. See [[StreamingQueryProgress]] for more information.
*
* @param description Description of the sink corresponding to this status.
* @since 2.1.0
*/
-@Experimental
@InterfaceStability.Evolving
class SinkProgress protected[sql](
val description: String) extends Serializable {
diff --git a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
index cfd7889b4ac2c..c6973bf41d34b 100644
--- a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
+++ b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
@@ -1,3 +1,7 @@
org.apache.spark.sql.sources.FakeSourceOne
org.apache.spark.sql.sources.FakeSourceTwo
org.apache.spark.sql.sources.FakeSourceThree
+org.apache.spark.sql.sources.FakeSourceFour
+org.apache.fakesource.FakeExternalSourceOne
+org.apache.fakesource.FakeExternalSourceTwo
+org.apache.fakesource.FakeExternalSourceThree
diff --git a/sql/core/src/test/resources/sql-tests/inputs/arithmetic.sql b/sql/core/src/test/resources/sql-tests/inputs/arithmetic.sql
index f62b10ca0037b..492a405d7ebbd 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/arithmetic.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/arithmetic.sql
@@ -32,3 +32,27 @@ select 1 - 2;
select 2 * 5;
select 5 % 3;
select pmod(-7, 3);
+
+-- math functions
+select cot(1);
+select cot(null);
+select cot(0);
+select cot(-1);
+
+-- ceil and ceiling
+select ceiling(0);
+select ceiling(1);
+select ceil(1234567890123456);
+select ceiling(1234567890123456);
+select ceil(0.01);
+select ceiling(-0.10);
+
+-- floor
+select floor(0);
+select floor(1);
+select floor(1234567890123456);
+select floor(0.01);
+select floor(-0.10);
+
+-- comparison operator
+select 1 > 0.00001
\ No newline at end of file
diff --git a/sql/core/src/test/resources/sql-tests/inputs/cast.sql b/sql/core/src/test/resources/sql-tests/inputs/cast.sql
index 5fae571945e41..629df59cff8b3 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/cast.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/cast.sql
@@ -40,4 +40,6 @@ SELECT CAST('-9223372036854775809' AS long);
SELECT CAST('9223372036854775807' AS long);
SELECT CAST('9223372036854775808' AS long);
+DESC FUNCTION boolean;
+DESC FUNCTION EXTENDED boolean;
-- TODO: migrate all cast tests here.
diff --git a/sql/core/src/test/resources/sql-tests/inputs/change-column.sql b/sql/core/src/test/resources/sql-tests/inputs/change-column.sql
index ad0f885f63d3d..2909024e4c9f7 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/change-column.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/change-column.sql
@@ -49,6 +49,7 @@ ALTER TABLE global_temp.global_temp_view CHANGE a a INT COMMENT 'this is column
-- Change column in partition spec (not supported yet)
CREATE TABLE partition_table(a INT, b STRING, c INT, d STRING) USING parquet PARTITIONED BY (c, d);
ALTER TABLE partition_table PARTITION (c = 1) CHANGE COLUMN a new_a INT;
+ALTER TABLE partition_table CHANGE COLUMN c c INT COMMENT 'this is column C';
-- DROP TEST TABLE
DROP TABLE test_change;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/comparator.sql b/sql/core/src/test/resources/sql-tests/inputs/comparator.sql
new file mode 100644
index 0000000000000..3e2447723e576
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/comparator.sql
@@ -0,0 +1,3 @@
+-- binary type
+select x'00' < x'0f';
+select x'00' < x'ff';
diff --git a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql
index 3fd1c37e71795..b42e92436da43 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql
@@ -2,3 +2,20 @@
-- [SPARK-16836] current_date and current_timestamp literals
select current_date = current_date(), current_timestamp = current_timestamp();
+
+-- [SPARK-22333]: timeFunctionCall has conflicts with columnReference
+create temporary view ttf1 as select * from values
+ (1, 2),
+ (2, 3)
+ as ttf1(current_date, current_timestamp);
+
+select current_date, current_timestamp from ttf1;
+
+create temporary view ttf2 as select * from values
+ (1, 2),
+ (2, 3)
+ as ttf2(a, b);
+
+select current_date = current_date(), current_timestamp = current_timestamp(), a, b from ttf2;
+
+select a, b from ttf2 order by a, current_date;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/describe.sql b/sql/core/src/test/resources/sql-tests/inputs/describe.sql
index 6de4cf0d5afa1..91b966829f8fb 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/describe.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/describe.sql
@@ -1,4 +1,5 @@
CREATE TABLE t (a STRING, b INT, c STRING, d STRING) USING parquet
+ OPTIONS (a '1', b '2')
PARTITIONED BY (c, d) CLUSTERED BY (a) SORTED BY (b ASC) INTO 2 BUCKETS
COMMENT 'table_comment';
@@ -13,6 +14,8 @@ CREATE TEMPORARY VIEW temp_Data_Source_View
CREATE VIEW v AS SELECT * FROM t;
+ALTER TABLE t SET TBLPROPERTIES (e = '3');
+
ALTER TABLE t ADD PARTITION (c='Us', d=1);
DESCRIBE t;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql b/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql
index f8135389a9e5a..8aff4cb524199 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql
@@ -54,4 +54,9 @@ SELECT course, year, GROUPING_ID(course, year) FROM courseSales GROUP BY CUBE(co
ORDER BY GROUPING(course), GROUPING(year), course, year;
SELECT course, year FROM courseSales GROUP BY course, year ORDER BY GROUPING(course);
SELECT course, year FROM courseSales GROUP BY course, year ORDER BY GROUPING_ID(course);
-SELECT course, year FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id;
\ No newline at end of file
+SELECT course, year FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id;
+
+-- Aliases in SELECT could be used in ROLLUP/CUBE/GROUPING SETS
+SELECT a + b AS k1, b AS k2, SUM(a - b) FROM testData GROUP BY CUBE(k1, k2);
+SELECT a + b AS k, b, SUM(a - b) FROM testData GROUP BY ROLLUP(k, b);
+SELECT a + b, b AS k, SUM(a - b) FROM testData GROUP BY a + b, k GROUPING SETS(k)
diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql
index 9c8d851e36e9b..928f766b4add2 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql
@@ -49,7 +49,10 @@ select a, count(a) from (select 1 as a) tmp group by 1 order by 1;
-- group by ordinal followed by having
select count(a), a from (select 1 as a) tmp group by 2 having a > 0;
--- turn of group by ordinal
+-- mixed cases: group-by ordinals and aliases
+select a, a AS k, count(b) from data group by k, 1;
+
+-- turn off group by ordinal
set spark.sql.groupByOrdinal=false;
-- can now group by negative literal
diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
index 4d0ed43153004..c5070b734d521 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
@@ -35,3 +35,37 @@ FROM testData;
-- Aggregate with foldable input and multiple distinct groups.
SELECT COUNT(DISTINCT b), COUNT(DISTINCT b, c) FROM (SELECT 1 AS a, 2 AS b, 3 AS c) GROUP BY a;
+
+-- Aliases in SELECT could be used in GROUP BY
+SELECT a AS k, COUNT(b) FROM testData GROUP BY k;
+SELECT a AS k, COUNT(b) FROM testData GROUP BY k HAVING k > 1;
+
+-- Aggregate functions cannot be used in GROUP BY
+SELECT COUNT(b) AS k FROM testData GROUP BY k;
+
+-- Test data.
+CREATE OR REPLACE TEMPORARY VIEW testDataHasSameNameWithAlias AS SELECT * FROM VALUES
+(1, 1, 3), (1, 2, 1) AS testDataHasSameNameWithAlias(k, a, v);
+SELECT k AS a, COUNT(v) FROM testDataHasSameNameWithAlias GROUP BY a;
+
+-- turn off group by aliases
+set spark.sql.groupByAliases=false;
+
+-- Check analysis exceptions
+SELECT a AS k, COUNT(b) FROM testData GROUP BY k;
+
+-- Aggregate with empty input and non-empty GroupBy expressions.
+SELECT a, COUNT(1) FROM testData WHERE false GROUP BY a;
+
+-- Aggregate with empty input and empty GroupBy expressions.
+SELECT COUNT(1) FROM testData WHERE false;
+SELECT 1 FROM (SELECT COUNT(1) FROM testData WHERE false) t;
+
+-- Aggregate with empty GroupBy expressions and filter on top
+SELECT 1 from (
+ SELECT 1 AS z,
+ MIN(a.x)
+ FROM (select 1 as x) a
+ WHERE false
+) b
+where b.z != b.z
diff --git a/sql/core/src/test/resources/sql-tests/inputs/having.sql b/sql/core/src/test/resources/sql-tests/inputs/having.sql
index 364c022d959dc..868a911e787f6 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/having.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/having.sql
@@ -13,3 +13,6 @@ SELECT count(k) FROM hav GROUP BY v + 1 HAVING v + 1 = 2;
-- SPARK-11032: resolve having correctly
SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0);
+
+-- SPARK-20329: make sure we handle timezones correctly
+SELECT a + b FROM VALUES (1L, 2), (3L, 4) AS T(a, b) GROUP BY a + b HAVING a + b > 1;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql
index 2b5b692d29ef4..f1461032065ad 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql
@@ -23,3 +23,7 @@ SELECT float(1), double(1), decimal(1);
SELECT date("2014-04-04"), timestamp(date("2014-04-04"));
-- error handling: only one argument
SELECT string(1, 2);
+
+-- SPARK-21555: RuntimeReplaceable used in group by
+CREATE TEMPORARY VIEW tempView1 AS VALUES (1, NAMED_STRUCT('col1', 'gamma', 'col2', 'delta')) AS T(id, st);
+SELECT nvl(st.col1, "value"), count(*) FROM tempView1 GROUP BY nvl(st.col1, "value");
diff --git a/sql/core/src/test/resources/sql-tests/inputs/struct.sql b/sql/core/src/test/resources/sql-tests/inputs/struct.sql
index e56344dc4de80..93a1238ab18c2 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/struct.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/struct.sql
@@ -18,3 +18,10 @@ SELECT ID, STRUCT(ST.*,CAST(ID AS STRING) AS E) NST FROM tbl_x;
-- Prepend a column to a struct
SELECT ID, STRUCT(CAST(ID AS STRING) AS AA, ST.*) NST FROM tbl_x;
+
+-- Select a column from a struct
+SELECT ID, STRUCT(ST.*).C NST FROM tbl_x;
+SELECT ID, STRUCT(ST.C, ST.D).D NST FROM tbl_x;
+
+-- Select an alias from a struct
+SELECT ID, STRUCT(ST.C as STC, ST.D as STD).STD FROM tbl_x;
\ No newline at end of file
diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql
index fb0d07fbdace7..1661209093fc4 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql
@@ -173,6 +173,16 @@ WHERE t1a = (SELECT max(t2a)
HAVING count(*) >= 0)
OR t1i > '2014-12-31';
+-- TC 02.03.01
+SELECT t1a
+FROM t1
+WHERE t1a = (SELECT max(t2a)
+ FROM t2
+ WHERE t2c = t1c
+ GROUP BY t2c
+ HAVING count(*) >= 1)
+OR t1i > '2014-12-31';
+
-- TC 02.04
-- t1 on the right of an outer join
-- can be reduced to inner join
diff --git a/sql/core/src/test/resources/sql-tests/inputs/window.sql b/sql/core/src/test/resources/sql-tests/inputs/window.sql
new file mode 100644
index 0000000000000..c800fc3d49891
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/window.sql
@@ -0,0 +1,69 @@
+-- Test data.
+CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES
+(null, "a"), (1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b"), (null, null), (3, null)
+AS testData(val, cate);
+
+-- RowsBetween
+SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val ROWS CURRENT ROW) FROM testData
+ORDER BY cate, val;
+SELECT val, cate, sum(val) OVER(PARTITION BY cate ORDER BY val
+ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) FROM testData ORDER BY cate, val;
+
+-- RangeBetween
+SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val RANGE 1 PRECEDING) FROM testData
+ORDER BY cate, val;
+SELECT val, cate, sum(val) OVER(PARTITION BY cate ORDER BY val
+RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val;
+
+-- RangeBetween with reverse OrderBy
+SELECT val, cate, sum(val) OVER(PARTITION BY cate ORDER BY val DESC
+RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val;
+
+-- Window functions
+SELECT val, cate,
+max(val) OVER w AS max,
+min(val) OVER w AS min,
+min(val) OVER w AS min,
+count(val) OVER w AS count,
+sum(val) OVER w AS sum,
+avg(val) OVER w AS avg,
+stddev(val) OVER w AS stddev,
+first_value(val) OVER w AS first_value,
+first_value(val, true) OVER w AS first_value_ignore_null,
+first_value(val, false) OVER w AS first_value_contain_null,
+last_value(val) OVER w AS last_value,
+last_value(val, true) OVER w AS last_value_ignore_null,
+last_value(val, false) OVER w AS last_value_contain_null,
+rank() OVER w AS rank,
+dense_rank() OVER w AS dense_rank,
+cume_dist() OVER w AS cume_dist,
+percent_rank() OVER w AS percent_rank,
+ntile(2) OVER w AS ntile,
+row_number() OVER w AS row_number,
+var_pop(val) OVER w AS var_pop,
+var_samp(val) OVER w AS var_samp,
+approx_count_distinct(val) OVER w AS approx_count_distinct
+FROM testData
+WINDOW w AS (PARTITION BY cate ORDER BY val)
+ORDER BY cate, val;
+
+-- Null inputs
+SELECT val, cate, avg(null) OVER(PARTITION BY cate ORDER BY val) FROM testData ORDER BY cate, val;
+
+-- OrderBy not specified
+SELECT val, cate, row_number() OVER(PARTITION BY cate) FROM testData ORDER BY cate, val;
+
+-- Over clause is empty
+SELECT val, cate, sum(val) OVER(), avg(val) OVER() FROM testData ORDER BY cate, val;
+
+-- first_value()/last_value() over ()
+SELECT val, cate,
+first_value(false) OVER w AS first_value,
+first_value(true, true) OVER w AS first_value_ignore_null,
+first_value(false, false) OVER w AS first_value_contain_null,
+last_value(false) OVER w AS last_value,
+last_value(true, true) OVER w AS last_value_ignore_null,
+last_value(false, false) OVER w AS last_value_contain_null
+FROM testData
+WINDOW w AS ()
+ORDER BY cate, val;
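
The new `window.sql` input exercises ROWS/RANGE frames and most of the window-function surface. A Scala-shell equivalent of the core case, handy for poking at the frames interactively; it reuses the same toy data as the test file:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()

spark.sql("""
  CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES
  (null, "a"), (1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b"), (null, null), (3, null)
  AS testData(val, cate)
""")

spark.sql("""
  SELECT val, cate, sum(val) OVER w AS sum, rank() OVER w AS rank
  FROM testData
  WINDOW w AS (PARTITION BY cate ORDER BY val)
  ORDER BY cate, val
""").show()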
diff --git a/sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out b/sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out
index ce42c016a7100..3811cd2c30986 100644
--- a/sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 28
+-- Number of queries: 44
-- !query 0
@@ -224,3 +224,135 @@ select pmod(-7, 3)
struct
-- !query 27 output
2
+
+
+-- !query 28
+select cot(1)
+-- !query 28 schema
+struct<>
+-- !query 28 output
+org.apache.spark.sql.AnalysisException
+Undefined function: 'cot'. This function is neither a registered temporary function nor a permanent function registered in the database 'default'.; line 1 pos 7
+
+
+-- !query 29
+select cot(null)
+-- !query 29 schema
+struct<>
+-- !query 29 output
+org.apache.spark.sql.AnalysisException
+Undefined function: 'cot'. This function is neither a registered temporary function nor a permanent function registered in the database 'default'.; line 1 pos 7
+
+
+-- !query 30
+select cot(0)
+-- !query 30 schema
+struct<>
+-- !query 30 output
+org.apache.spark.sql.AnalysisException
+Undefined function: 'cot'. This function is neither a registered temporary function nor a permanent function registered in the database 'default'.; line 1 pos 7
+
+
+-- !query 31
+select cot(-1)
+-- !query 31 schema
+struct<>
+-- !query 31 output
+org.apache.spark.sql.AnalysisException
+Undefined function: 'cot'. This function is neither a registered temporary function nor a permanent function registered in the database 'default'.; line 1 pos 7
+
+
+-- !query 32
+select ceiling(0)
+-- !query 32 schema
+struct
+-- !query 32 output
+0
+
+
+-- !query 33
+select ceiling(1)
+-- !query 33 schema
+struct
+-- !query 33 output
+1
+
+
+-- !query 34
+select ceil(1234567890123456)
+-- !query 34 schema
+struct
+-- !query 34 output
+1234567890123456
+
+
+-- !query 35
+select ceiling(1234567890123456)
+-- !query 35 schema
+struct
+-- !query 35 output
+1234567890123456
+
+
+-- !query 36
+select ceil(0.01)
+-- !query 36 schema
+struct
+-- !query 36 output
+1
+
+
+-- !query 37
+select ceiling(-0.10)
+-- !query 37 schema
+struct
+-- !query 37 output
+0
+
+
+-- !query 38
+select floor(0)
+-- !query 38 schema
+struct
+-- !query 38 output
+0
+
+
+-- !query 39
+select floor(1)
+-- !query 39 schema
+struct
+-- !query 39 output
+1
+
+
+-- !query 40
+select floor(1234567890123456)
+-- !query 40 schema
+struct
+-- !query 40 output
+1234567890123456
+
+
+-- !query 41
+select floor(0.01)
+-- !query 41 schema
+struct
+-- !query 41 output
+0
+
+
+-- !query 42
+select floor(-0.10)
+-- !query 42 schema
+struct
+-- !query 42 output
+-1
+
+
+-- !query 43
+select 1 > 0.00001
+-- !query 43 schema
+struct<(CAST(1 AS BIGINT) > 0):boolean>
+-- !query 43 output
+true
diff --git a/sql/core/src/test/resources/sql-tests/results/cast.sql.out b/sql/core/src/test/resources/sql-tests/results/cast.sql.out
index bfa29d7d2d597..4e6353b1f332c 100644
--- a/sql/core/src/test/resources/sql-tests/results/cast.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/cast.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 22
+-- Number of queries: 24
-- !query 0
@@ -176,3 +176,24 @@ SELECT CAST('9223372036854775808' AS long)
struct
-- !query 21 output
NULL
+
+
+-- !query 22
+DESC FUNCTION boolean
+-- !query 22 schema
+struct
+-- !query 22 output
+Class: org.apache.spark.sql.catalyst.expressions.Cast
+Function: boolean
+Usage: boolean(expr) - Casts the value `expr` to the target data type `boolean`.
+
+
+-- !query 23
+DESC FUNCTION EXTENDED boolean
+-- !query 23 schema
+struct
+-- !query 23 output
+Class: org.apache.spark.sql.catalyst.expressions.Cast
+Extended Usage:N/A.
+Function: boolean
+Usage: boolean(expr) - Casts the value `expr` to the target data type `boolean`.
diff --git a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out
index 678a3f0f0a3c6..ff1ecbcc44c23 100644
--- a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 32
+-- Number of queries: 33
-- !query 0
@@ -15,7 +15,6 @@ DESC test_change
-- !query 1 schema
struct
-- !query 1 output
-# col_name data_type comment
a int
b string
c int
@@ -35,7 +34,6 @@ DESC test_change
-- !query 3 schema
struct
-- !query 3 output
-# col_name data_type comment
a int
b string
c int
@@ -55,7 +53,6 @@ DESC test_change
-- !query 5 schema
struct
-- !query 5 output
-# col_name data_type comment
a int
b string
c int
@@ -94,7 +91,6 @@ DESC test_change
-- !query 8 schema
struct
-- !query 8 output
-# col_name data_type comment
a int
b string
c int
@@ -129,7 +125,6 @@ DESC test_change
-- !query 12 schema
struct
-- !query 12 output
-# col_name data_type comment
a int this is column a
b string #*02?`
c int
@@ -148,7 +143,6 @@ DESC test_change
-- !query 14 schema
struct
-- !query 14 output
-# col_name data_type comment
a int this is column a
b string #*02?`
c int
@@ -160,7 +154,7 @@ ALTER TABLE test_change CHANGE invalid_col invalid_col INT
struct<>
-- !query 15 output
org.apache.spark.sql.AnalysisException
-Invalid column reference 'invalid_col', table schema is 'StructType(StructField(a,IntegerType,true), StructField(b,StringType,true), StructField(c,IntegerType,true))';
+Can't find column `invalid_col` given table data columns [`a`, `b`, `c`];
-- !query 16
@@ -168,7 +162,6 @@ DESC test_change
-- !query 16 schema
struct
-- !query 16 output
-# col_name data_type comment
a int this is column a
b string #*02?`
c int
@@ -193,7 +186,6 @@ DESC test_change
-- !query 18 schema
struct
-- !query 18 output
-# col_name data_type comment
a int this is column a
b string #*02?`
c int
@@ -237,7 +229,6 @@ DESC test_change
-- !query 23 schema
struct
-- !query 23 output
-# col_name data_type comment
a int this is column A
b string #*02?`
c int
@@ -300,16 +291,25 @@ ALTER TABLE partition_table PARTITION (c = 1) CHANGE COLUMN a new_a INT
-- !query 30
-DROP TABLE test_change
+ALTER TABLE partition_table CHANGE COLUMN c c INT COMMENT 'this is column C'
-- !query 30 schema
struct<>
-- !query 30 output
-
+org.apache.spark.sql.AnalysisException
+Can't find column `c` given table data columns [`a`, `b`];
-- !query 31
-DROP TABLE partition_table
+DROP TABLE test_change
-- !query 31 schema
struct<>
-- !query 31 output
+
+
+-- !query 32
+DROP TABLE partition_table
+-- !query 32 schema
+struct<>
+-- !query 32 output
+
diff --git a/sql/core/src/test/resources/sql-tests/results/comparator.sql.out b/sql/core/src/test/resources/sql-tests/results/comparator.sql.out
new file mode 100644
index 0000000000000..afc7b5448b7b6
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/comparator.sql.out
@@ -0,0 +1,18 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 2
+
+
+-- !query 0
+select x'00' < x'0f'
+-- !query 0 schema
+struct<(X'00' < X'0F'):boolean>
+-- !query 0 output
+true
+
+
+-- !query 1
+select x'00' < x'ff'
+-- !query 1 schema
+struct<(X'00' < X'FF'):boolean>
+-- !query 1 output
+true
diff --git a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out
index 032e4258500fb..2d15718d3f8ec 100644
--- a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 1
+-- Number of queries: 6
-- !query 0
@@ -8,3 +8,50 @@ select current_date = current_date(), current_timestamp = current_timestamp()
struct<(current_date() = current_date()):boolean,(current_timestamp() = current_timestamp()):boolean>
-- !query 0 output
true true
+
+
+-- !query 1
+create temporary view ttf1 as select * from values
+ (1, 2),
+ (2, 3)
+ as ttf1(current_date, current_timestamp)
+-- !query 1 schema
+struct<>
+-- !query 1 output
+
+
+-- !query 2
+select current_date, current_timestamp from ttf1
+-- !query 2 schema
+struct
+-- !query 2 output
+1 2
+2 3
+
+
+-- !query 3
+create temporary view ttf2 as select * from values
+ (1, 2),
+ (2, 3)
+ as ttf2(a, b)
+-- !query 3 schema
+struct<>
+-- !query 3 output
+
+
+-- !query 4
+select current_date = current_date(), current_timestamp = current_timestamp(), a, b from ttf2
+-- !query 4 schema
+struct<(current_date() = current_date()):boolean,(current_timestamp() = current_timestamp()):boolean,a:int,b:int>
+-- !query 4 output
+true true 1 2
+true true 2 3
+
+
+-- !query 5
+select a, b from ttf2 order by a, current_date
+-- !query 5 schema
+struct
+-- !query 5 output
+1 2
+2 3
diff --git a/sql/core/src/test/resources/sql-tests/results/describe.sql.out b/sql/core/src/test/resources/sql-tests/results/describe.sql.out
index de10b29f3c65b..ab9f2783f06bb 100644
--- a/sql/core/src/test/resources/sql-tests/results/describe.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/describe.sql.out
@@ -1,9 +1,10 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 31
+-- Number of queries: 32
-- !query 0
CREATE TABLE t (a STRING, b INT, c STRING, d STRING) USING parquet
+ OPTIONS (a '1', b '2')
PARTITIONED BY (c, d) CLUSTERED BY (a) SORTED BY (b ASC) INTO 2 BUCKETS
COMMENT 'table_comment'
-- !query 0 schema
@@ -42,7 +43,7 @@ struct<>
-- !query 4
-ALTER TABLE t ADD PARTITION (c='Us', d=1)
+ALTER TABLE t SET TBLPROPERTIES (e = '3')
-- !query 4 schema
struct<>
-- !query 4 output
@@ -50,11 +51,18 @@ struct<>
-- !query 5
-DESCRIBE t
+ALTER TABLE t ADD PARTITION (c='Us', d=1)
-- !query 5 schema
-struct
+struct<>
-- !query 5 output
-# col_name data_type comment
+
+
+
+-- !query 6
+DESCRIBE t
+-- !query 6 schema
+struct
+-- !query 6 output
a string
b int
c string
@@ -65,12 +73,11 @@ c string
d string
--- !query 6
+-- !query 7
DESC default.t
--- !query 6 schema
+-- !query 7 schema
struct
--- !query 6 output
-# col_name data_type comment
+-- !query 7 output
a string
b int
c string
@@ -81,12 +88,11 @@ c string
d string
--- !query 7
+-- !query 8
DESC TABLE t
--- !query 7 schema
+-- !query 8 schema
struct
--- !query 7 output
-# col_name data_type comment
+-- !query 8 output
a string
b int
c string
@@ -97,12 +103,11 @@ c string
d string
--- !query 8
+-- !query 9
DESC FORMATTED t
--- !query 8 schema
+-- !query 9 schema
struct
--- !query 8 output
-# col_name data_type comment
+-- !query 9 output
a string
b int
c string
@@ -123,16 +128,17 @@ Num Buckets 2
Bucket Columns [`a`]
Sort Columns [`b`]
Comment table_comment
+Table Properties [e=3]
Location [not included in comparison]sql/core/spark-warehouse/t
+Storage Properties [a=1, b=2]
Partition Provider Catalog
--- !query 9
+-- !query 10
DESC EXTENDED t
--- !query 9 schema
+-- !query 10 schema
struct
--- !query 9 output
-# col_name data_type comment
+-- !query 10 output
a string
b int
c string
@@ -153,16 +159,17 @@ Num Buckets 2
Bucket Columns [`a`]
Sort Columns [`b`]
Comment table_comment
+Table Properties [e=3]
Location [not included in comparison]sql/core/spark-warehouse/t
+Storage Properties [a=1, b=2]
Partition Provider Catalog
--- !query 10
+-- !query 11
DESC t PARTITION (c='Us', d=1)
--- !query 10 schema
+-- !query 11 schema
struct
--- !query 10 output
-# col_name data_type comment
+-- !query 11 output
a string
b int
c string
@@ -173,12 +180,11 @@ c string
d string
--- !query 11
+-- !query 12
DESC EXTENDED t PARTITION (c='Us', d=1)
--- !query 11 schema
+-- !query 12 schema
struct
--- !query 11 output
-# col_name data_type comment
+-- !query 12 output
a string
b int
c string
@@ -193,20 +199,21 @@ Database default
Table t
Partition Values [c=Us, d=1]
Location [not included in comparison]sql/core/spark-warehouse/t/c=Us/d=1
+Storage Properties [a=1, b=2]
# Storage Information
Num Buckets 2
Bucket Columns [`a`]
Sort Columns [`b`]
-Location [not included in comparison]sql/core/spark-warehouse/t
+Location [not included in comparison]sql/core/spark-warehouse/t
+Storage Properties [a=1, b=2]
--- !query 12
+-- !query 13
DESC FORMATTED t PARTITION (c='Us', d=1)
--- !query 12 schema
+-- !query 13 schema
struct
--- !query 12 output
-# col_name data_type comment
+-- !query 13 output
a string
b int
c string
@@ -221,39 +228,41 @@ Database default
Table t
Partition Values [c=Us, d=1]
Location [not included in comparison]sql/core/spark-warehouse/t/c=Us/d=1
+Storage Properties [a=1, b=2]
# Storage Information
Num Buckets 2
Bucket Columns [`a`]
Sort Columns [`b`]
-Location [not included in comparison]sql/core/spark-warehouse/t
+Location [not included in comparison]sql/core/spark-warehouse/t
+Storage Properties [a=1, b=2]
--- !query 13
+-- !query 14
DESC t PARTITION (c='Us', d=2)
--- !query 13 schema
+-- !query 14 schema
struct<>
--- !query 13 output
+-- !query 14 output
org.apache.spark.sql.catalyst.analysis.NoSuchPartitionException
Partition not found in table 't' database 'default':
c -> Us
d -> 2;
--- !query 14
+-- !query 15
DESC t PARTITION (c='Us')
--- !query 14 schema
+-- !query 15 schema
struct<>
--- !query 14 output
+-- !query 15 output
org.apache.spark.sql.AnalysisException
Partition spec is invalid. The spec (c) must match the partition spec (c, d) defined in table '`default`.`t`';
--- !query 15
+-- !query 16
DESC t PARTITION (c='Us', d)
--- !query 15 schema
+-- !query 16 schema
struct<>
--- !query 15 output
+-- !query 16 output
org.apache.spark.sql.catalyst.parser.ParseException
PARTITION specification is incomplete: `d`(line 1, pos 0)
@@ -263,24 +272,11 @@ DESC t PARTITION (c='Us', d)
^^^
--- !query 16
-DESC temp_v
--- !query 16 schema
-struct
--- !query 16 output
-# col_name data_type comment
-a string
-b int
-c string
-d string
-
-
-- !query 17
-DESC TABLE temp_v
+DESC temp_v
-- !query 17 schema
struct
-- !query 17 output
-# col_name data_type comment
a string
b int
c string
@@ -288,11 +284,10 @@ d string
-- !query 18
-DESC FORMATTED temp_v
+DESC TABLE temp_v
-- !query 18 schema
struct
-- !query 18 output
-# col_name data_type comment
a string
b int
c string
@@ -300,11 +295,10 @@ d string
-- !query 19
-DESC EXTENDED temp_v
+DESC FORMATTED temp_v
-- !query 19 schema
struct
-- !query 19 output
-# col_name data_type comment
a string
b int
c string
@@ -312,11 +306,21 @@ d string
-- !query 20
-DESC temp_Data_Source_View
+DESC EXTENDED temp_v
-- !query 20 schema
struct
-- !query 20 output
-# col_name data_type comment
+a string
+b int
+c string
+d string
+
+
+-- !query 21
+DESC temp_Data_Source_View
+-- !query 21 schema
+struct
+-- !query 21 output
intType int test comment test1
stringType string
dateType date
@@ -335,45 +339,42 @@ arrayType array
structType struct
--- !query 21
+-- !query 22
DESC temp_v PARTITION (c='Us', d=1)
--- !query 21 schema
+-- !query 22 schema
struct<>
--- !query 21 output
+-- !query 22 output
org.apache.spark.sql.AnalysisException
DESC PARTITION is not allowed on a temporary view: temp_v;
--- !query 22
+-- !query 23
DESC v
--- !query 22 schema
+-- !query 23 schema
struct
--- !query 22 output
-# col_name data_type comment
+-- !query 23 output
a string
b int
c string
d string
--- !query 23
+-- !query 24
DESC TABLE v
--- !query 23 schema
+-- !query 24 schema
struct
--- !query 23 output
-# col_name data_type comment
+-- !query 24 output
a string
b int
c string
d string
--- !query 24
+-- !query 25
DESC FORMATTED v
--- !query 24 schema
+-- !query 25 schema
struct
--- !query 24 output
-# col_name data_type comment
+-- !query 25 output
a string
b int
c string
@@ -388,15 +389,14 @@ Type VIEW
View Text SELECT * FROM t
View Default Database default
View Query Output Columns [a, b, c, d]
-Properties [view.query.out.col.3=d, view.query.out.col.0=a, view.query.out.numCols=4, view.default.database=default, view.query.out.col.1=b, view.query.out.col.2=c]
+Table Properties [view.query.out.col.3=d, view.query.out.col.0=a, view.query.out.numCols=4, view.default.database=default, view.query.out.col.1=b, view.query.out.col.2=c]
--- !query 25
+-- !query 26
DESC EXTENDED v
--- !query 25 schema
+-- !query 26 schema
struct
--- !query 25 output
-# col_name data_type comment
+-- !query 26 output
a string
b int
c string
@@ -411,28 +411,20 @@ Type VIEW
View Text SELECT * FROM t
View Default Database default
View Query Output Columns [a, b, c, d]
-Properties [view.query.out.col.3=d, view.query.out.col.0=a, view.query.out.numCols=4, view.default.database=default, view.query.out.col.1=b, view.query.out.col.2=c]
-
-
--- !query 26
-DESC v PARTITION (c='Us', d=1)
--- !query 26 schema
-struct<>
--- !query 26 output
-org.apache.spark.sql.AnalysisException
-DESC PARTITION is not allowed on a view: v;
+Table Properties [view.query.out.col.3=d, view.query.out.col.0=a, view.query.out.numCols=4, view.default.database=default, view.query.out.col.1=b, view.query.out.col.2=c]
-- !query 27
-DROP TABLE t
+DESC v PARTITION (c='Us', d=1)
-- !query 27 schema
struct<>
-- !query 27 output
-
+org.apache.spark.sql.AnalysisException
+DESC PARTITION is not allowed on a view: v;
-- !query 28
-DROP VIEW temp_v
+DROP TABLE t
-- !query 28 schema
struct<>
-- !query 28 output
@@ -440,7 +432,7 @@ struct<>
-- !query 29
-DROP VIEW temp_Data_Source_View
+DROP VIEW temp_v
-- !query 29 schema
struct<>
-- !query 29 output
@@ -448,8 +440,16 @@ struct<>
-- !query 30
-DROP VIEW v
+DROP VIEW temp_Data_Source_View
-- !query 30 schema
struct<>
-- !query 30 output
+
+
+-- !query 31
+DROP VIEW v
+-- !query 31 schema
+struct<>
+-- !query 31 output
+
diff --git a/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out b/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out
index 825e8f5488c8b..ce7a16a4d0c81 100644
--- a/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 26
+-- Number of queries: 29
-- !query 0
@@ -328,3 +328,50 @@ struct<>
-- !query 25 output
org.apache.spark.sql.AnalysisException
grouping__id is deprecated; use grouping_id() instead;
+
+
+-- !query 26
+SELECT a + b AS k1, b AS k2, SUM(a - b) FROM testData GROUP BY CUBE(k1, k2)
+-- !query 26 schema
+struct
+-- !query 26 output
+2 1 0
+2 NULL 0
+3 1 1
+3 2 -1
+3 NULL 0
+4 1 2
+4 2 0
+4 NULL 2
+5 2 1
+5 NULL 1
+NULL 1 3
+NULL 2 0
+NULL NULL 3
+
+
+-- !query 27
+SELECT a + b AS k, b, SUM(a - b) FROM testData GROUP BY ROLLUP(k, b)
+-- !query 27 schema
+struct
+-- !query 27 output
+2 1 0
+2 NULL 0
+3 1 1
+3 2 -1
+3 NULL 0
+4 1 2
+4 2 0
+4 NULL 2
+5 2 1
+5 NULL 1
+NULL NULL 3
+
+
+-- !query 28
+SELECT a + b, b AS k, SUM(a - b) FROM testData GROUP BY a + b, k GROUPING SETS(k)
+-- !query 28 schema
+struct<(a + b):int,k:int,sum((a - b)):bigint>
+-- !query 28 output
+NULL 1 3
+NULL 2 0
diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out
index c0930bbde69a4..9ecbe19078dd6 100644
--- a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 19
+-- Number of queries: 20
-- !query 0
@@ -122,7 +122,7 @@ select a, b, sum(b) from data group by 3
struct<>
-- !query 11 output
org.apache.spark.sql.AnalysisException
-GROUP BY position 3 is an aggregate function, and aggregate functions are not allowed in GROUP BY; line 1 pos 39
+aggregate functions are not allowed in GROUP BY, but found sum(CAST(data.`b` AS BIGINT));
-- !query 12
@@ -131,7 +131,7 @@ select a, b, sum(b) + 2 from data group by 3
struct<>
-- !query 12 output
org.apache.spark.sql.AnalysisException
-GROUP BY position 3 is an aggregate function, and aggregate functions are not allowed in GROUP BY; line 1 pos 43
+aggregate functions are not allowed in GROUP BY, but found (sum(CAST(data.`b` AS BIGINT)) + CAST(2 AS BIGINT));
-- !query 13
@@ -173,16 +173,26 @@ struct
-- !query 17
-set spark.sql.groupByOrdinal=false
+select a, a AS k, count(b) from data group by k, 1
-- !query 17 schema
-struct
+struct
-- !query 17 output
-spark.sql.groupByOrdinal false
+1 1 2
+2 2 2
+3 3 2
-- !query 18
-select sum(b) from data group by -1
+set spark.sql.groupByOrdinal=false
-- !query 18 schema
-struct
+struct
-- !query 18 output
+spark.sql.groupByOrdinal false
+
+
+-- !query 19
+select sum(b) from data group by -1
+-- !query 19 schema
+struct
+-- !query 19 output
9
diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
index 4b87d5161fc0e..ed66c03a12081 100644
--- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 15
+-- Number of queries: 26
-- !query 0
@@ -139,3 +139,105 @@ SELECT COUNT(DISTINCT b), COUNT(DISTINCT b, c) FROM (SELECT 1 AS a, 2 AS b, 3 AS
struct
-- !query 14 output
1 1
+
+
+-- !query 15
+SELECT a AS k, COUNT(b) FROM testData GROUP BY k
+-- !query 15 schema
+struct
+-- !query 15 output
+1 2
+2 2
+3 2
+NULL 1
+
+
+-- !query 16
+SELECT a AS k, COUNT(b) FROM testData GROUP BY k HAVING k > 1
+-- !query 16 schema
+struct
+-- !query 16 output
+2 2
+3 2
+
+
+-- !query 17
+SELECT COUNT(b) AS k FROM testData GROUP BY k
+-- !query 17 schema
+struct<>
+-- !query 17 output
+org.apache.spark.sql.AnalysisException
+aggregate functions are not allowed in GROUP BY, but found count(testdata.`b`);
+
+
+-- !query 18
+CREATE OR REPLACE TEMPORARY VIEW testDataHasSameNameWithAlias AS SELECT * FROM VALUES
+(1, 1, 3), (1, 2, 1) AS testDataHasSameNameWithAlias(k, a, v)
+-- !query 18 schema
+struct<>
+-- !query 18 output
+
+
+
+-- !query 19
+SELECT k AS a, COUNT(v) FROM testDataHasSameNameWithAlias GROUP BY a
+-- !query 19 schema
+struct<>
+-- !query 19 output
+org.apache.spark.sql.AnalysisException
+expression 'testdatahassamenamewithalias.`k`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.;
+
+
+-- !query 20
+set spark.sql.groupByAliases=false
+-- !query 20 schema
+struct
+-- !query 20 output
+spark.sql.groupByAliases false
+
+
+-- !query 21
+SELECT a AS k, COUNT(b) FROM testData GROUP BY k
+-- !query 21 schema
+struct<>
+-- !query 21 output
+org.apache.spark.sql.AnalysisException
+cannot resolve '`k`' given input columns: [a, b]; line 1 pos 47
+
+
+-- !query 22
+SELECT a, COUNT(1) FROM testData WHERE false GROUP BY a
+-- !query 22 schema
+struct
+-- !query 22 output
+
+
+
+-- !query 23
+SELECT COUNT(1) FROM testData WHERE false
+-- !query 23 schema
+struct
+-- !query 23 output
+0
+
+
+-- !query 24
+SELECT 1 FROM (SELECT COUNT(1) FROM testData WHERE false) t
+-- !query 24 schema
+struct<1:int>
+-- !query 24 output
+1
+
+
+-- !query 25
+SELECT 1 from (
+ SELECT 1 AS z,
+ MIN(a.x)
+ FROM (select 1 as x) a
+ WHERE false
+) b
+where b.z != b.z
+-- !query 25 schema
+struct<1:int>
+-- !query 25 output
+
diff --git a/sql/core/src/test/resources/sql-tests/results/having.sql.out b/sql/core/src/test/resources/sql-tests/results/having.sql.out
index e0923832673cb..d87ee5221647f 100644
--- a/sql/core/src/test/resources/sql-tests/results/having.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/having.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 4
+-- Number of queries: 5
-- !query 0
@@ -38,3 +38,12 @@ SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0)
struct
-- !query 3 output
1
+
+
+-- !query 4
+SELECT a + b FROM VALUES (1L, 2), (3L, 4) AS T(a, b) GROUP BY a + b HAVING a + b > 1
+-- !query 4 schema
+struct<(a + CAST(b AS BIGINT)):bigint>
+-- !query 4 output
+3
+7
diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out
index 315e1730ce7df..f569245a432b9 100644
--- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out
@@ -24,7 +24,7 @@ Extended Usage:
{"a":1,"b":2}
> SELECT to_json(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy'));
{"time":"26/08/2015"}
- > SELECT to_json(array(named_struct('a', 1, 'b', 2));
+ > SELECT to_json(array(named_struct('a', 1, 'b', 2)));
[{"a":1,"b":2}]
Function: to_json
@@ -141,7 +141,7 @@ struct<>
-- !query 13 output
org.apache.spark.sql.AnalysisException
-DataType invalidtype() is not supported.(line 1, pos 2)
+DataType invalidtype is not supported.(line 1, pos 2)
== SQL ==
a InvalidType
diff --git a/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out
index 9f0b95994be53..e035505f15d28 100644
--- a/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 13
+-- Number of queries: 15
-- !query 0
@@ -88,7 +88,7 @@ Project [coalesce(cast(id#xL as string), x) AS ifnull(`id`, 'x')#x, id#xL AS nul
== Physical Plan ==
*Project [coalesce(cast(id#xL as string), x) AS ifnull(`id`, 'x')#x, id#xL AS nullif(`id`, 'x')#xL, coalesce(cast(id#xL as string), x) AS nvl(`id`, 'x')#x, x AS nvl2(`id`, 'x', 'y')#x]
-+- *Range (0, 2, step=1, splits=None)
++- *Range (0, 2, step=1, splits=2)
-- !query 9
@@ -122,3 +122,19 @@ struct<>
-- !query 12 output
org.apache.spark.sql.AnalysisException
Function string accepts only one argument; line 1 pos 7
+
+
+-- !query 13
+CREATE TEMPORARY VIEW tempView1 AS VALUES (1, NAMED_STRUCT('col1', 'gamma', 'col2', 'delta')) AS T(id, st)
+-- !query 13 schema
+struct<>
+-- !query 13 output
+
+
+
+-- !query 14
+SELECT nvl(st.col1, "value"), count(*) FROM tempView1 GROUP BY nvl(st.col1, "value")
+-- !query 14 schema
+struct
+-- !query 14 output
+gamma 1
diff --git a/sql/core/src/test/resources/sql-tests/results/struct.sql.out b/sql/core/src/test/resources/sql-tests/results/struct.sql.out
index 3e32f46195464..1da33bc736f0b 100644
--- a/sql/core/src/test/resources/sql-tests/results/struct.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/struct.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 6
+-- Number of queries: 9
-- !query 0
@@ -58,3 +58,33 @@ struct>
1 {"AA":"1","C":"gamma","D":"delta"}
2 {"AA":"2","C":"epsilon","D":"eta"}
3 {"AA":"3","C":"theta","D":"iota"}
+
+
+-- !query 6
+SELECT ID, STRUCT(ST.*).C NST FROM tbl_x
+-- !query 6 schema
+struct
+-- !query 6 output
+1 gamma
+2 epsilon
+3 theta
+
+
+-- !query 7
+SELECT ID, STRUCT(ST.C, ST.D).D NST FROM tbl_x
+-- !query 7 schema
+struct
+-- !query 7 output
+1 delta
+2 eta
+3 iota
+
+
+-- !query 8
+SELECT ID, STRUCT(ST.C as STC, ST.D as STD).STD FROM tbl_x
+-- !query 8 schema
+struct
+-- !query 8 output
+1 delta
+2 eta
+3 iota
diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out
index 8b29300e71f90..a2b86db3e4f4c 100644
--- a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 26
+-- Number of queries: 29
-- !query 0
@@ -293,6 +293,21 @@ val1d
-- !query 19
+SELECT t1a
+FROM t1
+WHERE t1a = (SELECT max(t2a)
+ FROM t2
+ WHERE t2c = t1c
+ GROUP BY t2c
+ HAVING count(*) >= 1)
+OR t1i > '2014-12-31'
+-- !query 19 schema
+struct
+-- !query 19 output
+val1c
+val1d
+
+-- !query 22
SELECT count(t1a)
FROM t1 RIGHT JOIN t2
ON t1d = t2d
@@ -300,13 +315,13 @@ WHERE t1a < (SELECT max(t2a)
FROM t2
WHERE t2c = t1c
GROUP BY t2c)
--- !query 19 schema
+-- !query 22 schema
struct
--- !query 19 output
+-- !query 22 output
7
--- !query 20
+-- !query 23
SELECT t1a
FROM t1
WHERE t1b <= (SELECT max(t2b)
@@ -317,14 +332,14 @@ AND t1b >= (SELECT min(t2b)
FROM t2
WHERE t2c = t1c
GROUP BY t2c)
--- !query 20 schema
+-- !query 23 schema
struct
--- !query 20 output
+-- !query 23 output
val1b
val1c
--- !query 21
+-- !query 24
SELECT t1a
FROM t1
WHERE t1a <= (SELECT max(t2a)
@@ -338,14 +353,14 @@ WHERE t1a >= (SELECT min(t2a)
FROM t2
WHERE t2c = t1c
GROUP BY t2c)
--- !query 21 schema
+-- !query 24 schema
struct
--- !query 21 output
+-- !query 24 output
val1b
val1c
--- !query 22
+-- !query 25
SELECT t1a
FROM t1
WHERE t1a <= (SELECT max(t2a)
@@ -359,9 +374,9 @@ WHERE t1a >= (SELECT min(t2a)
FROM t2
WHERE t2c = t1c
GROUP BY t2c)
--- !query 22 schema
+-- !query 25 schema
struct
--- !query 22 output
+-- !query 25 output
val1a
val1a
val1b
@@ -372,7 +387,7 @@ val1d
val1d
--- !query 23
+-- !query 26
SELECT t1a
FROM t1
WHERE t1a <= (SELECT max(t2a)
@@ -386,16 +401,16 @@ WHERE t1a >= (SELECT min(t2a)
FROM t2
WHERE t2c = t1c
GROUP BY t2c)
--- !query 23 schema
+-- !query 26 schema
struct
--- !query 23 output
+-- !query 26 output
val1a
val1b
val1c
val1d
--- !query 24
+-- !query 27
SELECT t1a
FROM t1
WHERE t1a <= (SELECT max(t2a)
@@ -409,13 +424,13 @@ WHERE t1a >= (SELECT min(t2a)
FROM t2
WHERE t2c = t1c
GROUP BY t2c)
--- !query 24 schema
+-- !query 27 schema
struct
--- !query 24 output
+-- !query 27 output
val1a
--- !query 25
+-- !query 28
SELECT t1a
FROM t1
GROUP BY t1a, t1c
@@ -423,8 +438,8 @@ HAVING max(t1b) <= (SELECT max(t2b)
FROM t2
WHERE t2c = t1c
GROUP BY t2c)
--- !query 25 schema
+-- !query 28 schema
struct
--- !query 25 output
+-- !query 28 output
val1b
val1c
diff --git a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out
index acd4ecf14617e..e2ee970d35f60 100644
--- a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out
@@ -102,4 +102,4 @@ EXPLAIN select * from RaNgE(2)
struct
-- !query 8 output
== Physical Plan ==
-*Range (0, 2, step=1, splits=None)
+*Range (0, 2, step=1, splits=2)
diff --git a/sql/core/src/test/resources/sql-tests/results/window.sql.out b/sql/core/src/test/resources/sql-tests/results/window.sql.out
new file mode 100644
index 0000000000000..aa5856138ed81
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/window.sql.out
@@ -0,0 +1,204 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 11
+
+
+-- !query 0
+CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES
+(null, "a"), (1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b"), (null, null), (3, null)
+AS testData(val, cate)
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val ROWS CURRENT ROW) FROM testData
+ORDER BY cate, val
+-- !query 1 schema
+struct
+-- !query 1 output
+NULL NULL 0
+3 NULL 1
+NULL a 0
+1 a 1
+1 a 1
+2 a 1
+1 b 1
+2 b 1
+3 b 1
+
+
+-- !query 2
+SELECT val, cate, sum(val) OVER(PARTITION BY cate ORDER BY val
+ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) FROM testData ORDER BY cate, val
+-- !query 2 schema
+struct
+-- !query 2 output
+NULL NULL 3
+3 NULL 3
+NULL a 1
+1 a 2
+1 a 4
+2 a 4
+1 b 3
+2 b 6
+3 b 6
+
+
+-- !query 3
+SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val RANGE 1 PRECEDING) FROM testData
+ORDER BY cate, val
+-- !query 3 schema
+struct
+-- !query 3 output
+NULL NULL 0
+3 NULL 1
+NULL a 0
+1 a 2
+1 a 2
+2 a 3
+1 b 1
+2 b 2
+3 b 2
+
+
+-- !query 4
+SELECT val, cate, sum(val) OVER(PARTITION BY cate ORDER BY val
+RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val
+-- !query 4 schema
+struct
+-- !query 4 output
+NULL NULL NULL
+3 NULL 3
+NULL a NULL
+1 a 4
+1 a 4
+2 a 2
+1 b 3
+2 b 5
+3 b 3
+
+
+-- !query 5
+SELECT val, cate, sum(val) OVER(PARTITION BY cate ORDER BY val DESC
+RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val
+-- !query 5 schema
+struct
+-- !query 5 output
+NULL NULL NULL
+3 NULL 3
+NULL a NULL
+1 a 2
+1 a 2
+2 a 4
+1 b 1
+2 b 3
+3 b 5
+
+
+-- !query 6
+SELECT val, cate,
+max(val) OVER w AS max,
+min(val) OVER w AS min,
+min(val) OVER w AS min,
+count(val) OVER w AS count,
+sum(val) OVER w AS sum,
+avg(val) OVER w AS avg,
+stddev(val) OVER w AS stddev,
+first_value(val) OVER w AS first_value,
+first_value(val, true) OVER w AS first_value_ignore_null,
+first_value(val, false) OVER w AS first_value_contain_null,
+last_value(val) OVER w AS last_value,
+last_value(val, true) OVER w AS last_value_ignore_null,
+last_value(val, false) OVER w AS last_value_contain_null,
+rank() OVER w AS rank,
+dense_rank() OVER w AS dense_rank,
+cume_dist() OVER w AS cume_dist,
+percent_rank() OVER w AS percent_rank,
+ntile(2) OVER w AS ntile,
+row_number() OVER w AS row_number,
+var_pop(val) OVER w AS var_pop,
+var_samp(val) OVER w AS var_samp,
+approx_count_distinct(val) OVER w AS approx_count_distinct
+FROM testData
+WINDOW w AS (PARTITION BY cate ORDER BY val)
+ORDER BY cate, val
+-- !query 6 schema
+struct
+-- !query 6 output
+NULL NULL NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.5 0.0 1 1 NULL NULL 0
+3 NULL 3 3 3 1 3 3.0 NaN NULL 3 NULL 3 3 3 2 2 1.0 1.0 2 2 0.0 NaN 1
+NULL a NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.25 0.0 1 1 NULL NULL 0
+1 a 1 1 1 2 2 1.0 0.0 NULL 1 NULL 1 1 1 2 2 0.75 0.3333333333333333 1 2 0.0 0.0 1
+1 a 1 1 1 2 2 1.0 0.0 NULL 1 NULL 1 1 1 2 2 0.75 0.3333333333333333 2 3 0.0 0.0 1
+2 a 2 1 1 3 4 1.3333333333333333 0.5773502691896258 NULL 1 NULL 2 2 2 4 3 1.0 1.0 2 4 0.22222222222222224 0.33333333333333337 2
+1 b 1 1 1 1 1 1.0 NaN 1 1 1 1 1 1 1 1 0.3333333333333333 0.0 1 1 0.0 NaN 1
+2 b 2 1 1 2 3 1.5 0.7071067811865476 1 1 1 2 2 2 2 2 0.6666666666666666 0.5 1 2 0.25 0.5 2
+3 b 3 1 1 3 6 2.0 1.0 1 1 1 3 3 3 3 3 1.0 1.0 2 3 0.6666666666666666 1.0 3
+
+
+-- !query 7
+SELECT val, cate, avg(null) OVER(PARTITION BY cate ORDER BY val) FROM testData ORDER BY cate, val
+-- !query 7 schema
+struct
+-- !query 7 output
+NULL NULL NULL
+3 NULL NULL
+NULL a NULL
+1 a NULL
+1 a NULL
+2 a NULL
+1 b NULL
+2 b NULL
+3 b NULL
+
+
+-- !query 8
+SELECT val, cate, row_number() OVER(PARTITION BY cate) FROM testData ORDER BY cate, val
+-- !query 8 schema
+struct<>
+-- !query 8 output
+org.apache.spark.sql.AnalysisException
+Window function row_number() requires window to be ordered, please add ORDER BY clause. For example SELECT row_number()(value_expr) OVER (PARTITION BY window_partition ORDER BY window_ordering) from table;
+
+
+-- !query 9
+SELECT val, cate, sum(val) OVER(), avg(val) OVER() FROM testData ORDER BY cate, val
+-- !query 9 schema
+struct
+-- !query 9 output
+NULL NULL 13 1.8571428571428572
+3 NULL 13 1.8571428571428572
+NULL a 13 1.8571428571428572
+1 a 13 1.8571428571428572
+1 a 13 1.8571428571428572
+2 a 13 1.8571428571428572
+1 b 13 1.8571428571428572
+2 b 13 1.8571428571428572
+3 b 13 1.8571428571428572
+
+
+-- !query 10
+SELECT val, cate,
+first_value(false) OVER w AS first_value,
+first_value(true, true) OVER w AS first_value_ignore_null,
+first_value(false, false) OVER w AS first_value_contain_null,
+last_value(false) OVER w AS last_value,
+last_value(true, true) OVER w AS last_value_ignore_null,
+last_value(false, false) OVER w AS last_value_contain_null
+FROM testData
+WINDOW w AS ()
+ORDER BY cate, val
+-- !query 10 schema
+struct
+-- !query 10 output
+NULL NULL false true false false true false
+3 NULL false true false false true false
+NULL a false true false false true false
+1 a false true false false true false
+1 a false true false false true false
+2 a false true false false true false
+1 b false true false false true false
+2 b false true false false true false
+3 b false true false false true false
diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q10.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q10.sql
new file mode 100755
index 0000000000000..79dd3d516e8c7
--- /dev/null
+++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q10.sql
@@ -0,0 +1,70 @@
+-- start query 10 in stream 0 using template query10.tpl
+with
+v1 as (
+ select
+ ws_bill_customer_sk as customer_sk
+ from web_sales,
+ date_dim
+ where ws_sold_date_sk = d_date_sk
+ and d_year = 2002
+ and d_moy between 4 and 4+3
+ union all
+ select
+ cs_ship_customer_sk as customer_sk
+ from catalog_sales,
+ date_dim
+ where cs_sold_date_sk = d_date_sk
+ and d_year = 2002
+ and d_moy between 4 and 4+3
+),
+v2 as (
+ select
+ ss_customer_sk as customer_sk
+ from store_sales,
+ date_dim
+ where ss_sold_date_sk = d_date_sk
+ and d_year = 2002
+ and d_moy between 4 and 4+3
+)
+select
+ cd_gender,
+ cd_marital_status,
+ cd_education_status,
+ count(*) cnt1,
+ cd_purchase_estimate,
+ count(*) cnt2,
+ cd_credit_rating,
+ count(*) cnt3,
+ cd_dep_count,
+ count(*) cnt4,
+ cd_dep_employed_count,
+ count(*) cnt5,
+ cd_dep_college_count,
+ count(*) cnt6
+from customer c
+join customer_address ca on (c.c_current_addr_sk = ca.ca_address_sk)
+join customer_demographics on (cd_demo_sk = c.c_current_cdemo_sk)
+left semi join v1 on (v1.customer_sk = c.c_customer_sk)
+left semi join v2 on (v2.customer_sk = c.c_customer_sk)
+where
+ ca_county in ('Walker County','Richland County','Gaines County','Douglas County','Dona Ana County')
+group by
+ cd_gender,
+ cd_marital_status,
+ cd_education_status,
+ cd_purchase_estimate,
+ cd_credit_rating,
+ cd_dep_count,
+ cd_dep_employed_count,
+ cd_dep_college_count
+order by
+ cd_gender,
+ cd_marital_status,
+ cd_education_status,
+ cd_purchase_estimate,
+ cd_credit_rating,
+ cd_dep_count,
+ cd_dep_employed_count,
+ cd_dep_college_count
+limit 100
+-- end query 10 in stream 0 using template query10.tpl
diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q19.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q19.sql
new file mode 100755
index 0000000000000..1799827762916
--- /dev/null
+++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q19.sql
@@ -0,0 +1,38 @@
+-- start query 19 in stream 0 using template query19.tpl
+select
+ i_brand_id brand_id,
+ i_brand brand,
+ i_manufact_id,
+ i_manufact,
+ sum(ss_ext_sales_price) ext_price
+from
+ date_dim,
+ store_sales,
+ item,
+ customer,
+ customer_address,
+ store
+where
+ d_date_sk = ss_sold_date_sk
+ and ss_item_sk = i_item_sk
+ and i_manager_id = 7
+ and d_moy = 11
+ and d_year = 1999
+ and ss_customer_sk = c_customer_sk
+ and c_current_addr_sk = ca_address_sk
+ and substr(ca_zip, 1, 5) <> substr(s_zip, 1, 5)
+ and ss_store_sk = s_store_sk
+ and ss_sold_date_sk between 2451484 and 2451513 -- partition key filter
+group by
+ i_brand,
+ i_brand_id,
+ i_manufact_id,
+ i_manufact
+order by
+ ext_price desc,
+ i_brand,
+ i_brand_id,
+ i_manufact_id,
+ i_manufact
+limit 100
+-- end query 19 in stream 0 using template query19.tpl
diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q27.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q27.sql
new file mode 100755
index 0000000000000..dedbc62a2ab2e
--- /dev/null
+++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q27.sql
@@ -0,0 +1,43 @@
+-- start query 27 in stream 0 using template query27.tpl
+ with results as
+ (select i_item_id,
+ s_state,
+ ss_quantity agg1,
+ ss_list_price agg2,
+ ss_coupon_amt agg3,
+ ss_sales_price agg4
+ --0 as g_state,
+ --avg(ss_quantity) agg1,
+ --avg(ss_list_price) agg2,
+ --avg(ss_coupon_amt) agg3,
+ --avg(ss_sales_price) agg4
+ from store_sales, customer_demographics, date_dim, store, item
+ where ss_sold_date_sk = d_date_sk and
+ ss_sold_date_sk between 2451545 and 2451910 and
+ ss_item_sk = i_item_sk and
+ ss_store_sk = s_store_sk and
+ ss_cdemo_sk = cd_demo_sk and
+ cd_gender = 'F' and
+ cd_marital_status = 'D' and
+ cd_education_status = 'Primary' and
+ d_year = 2000 and
+ s_state in ('TN','AL', 'SD', 'SD', 'SD', 'SD')
+ --group by i_item_id, s_state
+ )
+
+ select i_item_id,
+ s_state, g_state, agg1, agg2, agg3, agg4
+ from (
+ select i_item_id, s_state, 0 as g_state, avg(agg1) agg1, avg(agg2) agg2, avg(agg3) agg3, avg(agg4) agg4 from results
+ group by i_item_id, s_state
+ union all
+ select i_item_id, NULL AS s_state, 1 AS g_state, avg(agg1) agg1, avg(agg2) agg2, avg(agg3) agg3,
+ avg(agg4) agg4 from results
+ group by i_item_id
+ union all
+ select NULL AS i_item_id, NULL as s_state, 1 as g_state, avg(agg1) agg1, avg(agg2) agg2, avg(agg3) agg3,
+ avg(agg4) agg4 from results
+ ) foo
+ order by i_item_id, s_state
+ limit 100
+-- end query 27 in stream 0 using template query27.tpl
diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q3.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q3.sql
new file mode 100755
index 0000000000000..35b0a20f80a4e
--- /dev/null
+++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q3.sql
@@ -0,0 +1,228 @@
+-- start query 3 in stream 0 using template query3.tpl
+select
+ dt.d_year,
+ item.i_brand_id brand_id,
+ item.i_brand brand,
+ sum(ss_net_profit) sum_agg
+from
+ date_dim dt,
+ store_sales,
+ item
+where
+ dt.d_date_sk = store_sales.ss_sold_date_sk
+ and store_sales.ss_item_sk = item.i_item_sk
+ and item.i_manufact_id = 436
+ and dt.d_moy = 12
+ -- partition key filters
+ and (
+ss_sold_date_sk between 2415355 and 2415385
+or ss_sold_date_sk between 2415720 and 2415750
+or ss_sold_date_sk between 2416085 and 2416115
+or ss_sold_date_sk between 2416450 and 2416480
+or ss_sold_date_sk between 2416816 and 2416846
+or ss_sold_date_sk between 2417181 and 2417211
+or ss_sold_date_sk between 2417546 and 2417576
+or ss_sold_date_sk between 2417911 and 2417941
+or ss_sold_date_sk between 2418277 and 2418307
+or ss_sold_date_sk between 2418642 and 2418672
+or ss_sold_date_sk between 2419007 and 2419037
+or ss_sold_date_sk between 2419372 and 2419402
+or ss_sold_date_sk between 2419738 and 2419768
+or ss_sold_date_sk between 2420103 and 2420133
+or ss_sold_date_sk between 2420468 and 2420498
+or ss_sold_date_sk between 2420833 and 2420863
+or ss_sold_date_sk between 2421199 and 2421229
+or ss_sold_date_sk between 2421564 and 2421594
+or ss_sold_date_sk between 2421929 and 2421959
+or ss_sold_date_sk between 2422294 and 2422324
+or ss_sold_date_sk between 2422660 and 2422690
+or ss_sold_date_sk between 2423025 and 2423055
+or ss_sold_date_sk between 2423390 and 2423420
+or ss_sold_date_sk between 2423755 and 2423785
+or ss_sold_date_sk between 2424121 and 2424151
+or ss_sold_date_sk between 2424486 and 2424516
+or ss_sold_date_sk between 2424851 and 2424881
+or ss_sold_date_sk between 2425216 and 2425246
+or ss_sold_date_sk between 2425582 and 2425612
+or ss_sold_date_sk between 2425947 and 2425977
+or ss_sold_date_sk between 2426312 and 2426342
+or ss_sold_date_sk between 2426677 and 2426707
+or ss_sold_date_sk between 2427043 and 2427073
+or ss_sold_date_sk between 2427408 and 2427438
+or ss_sold_date_sk between 2427773 and 2427803
+or ss_sold_date_sk between 2428138 and 2428168
+or ss_sold_date_sk between 2428504 and 2428534
+or ss_sold_date_sk between 2428869 and 2428899
+or ss_sold_date_sk between 2429234 and 2429264
+or ss_sold_date_sk between 2429599 and 2429629
+or ss_sold_date_sk between 2429965 and 2429995
+or ss_sold_date_sk between 2430330 and 2430360
+or ss_sold_date_sk between 2430695 and 2430725
+or ss_sold_date_sk between 2431060 and 2431090
+or ss_sold_date_sk between 2431426 and 2431456
+or ss_sold_date_sk between 2431791 and 2431821
+or ss_sold_date_sk between 2432156 and 2432186
+or ss_sold_date_sk between 2432521 and 2432551
+or ss_sold_date_sk between 2432887 and 2432917
+or ss_sold_date_sk between 2433252 and 2433282
+or ss_sold_date_sk between 2433617 and 2433647
+or ss_sold_date_sk between 2433982 and 2434012
+or ss_sold_date_sk between 2434348 and 2434378
+or ss_sold_date_sk between 2434713 and 2434743
+or ss_sold_date_sk between 2435078 and 2435108
+or ss_sold_date_sk between 2435443 and 2435473
+or ss_sold_date_sk between 2435809 and 2435839
+or ss_sold_date_sk between 2436174 and 2436204
+or ss_sold_date_sk between 2436539 and 2436569
+or ss_sold_date_sk between 2436904 and 2436934
+or ss_sold_date_sk between 2437270 and 2437300
+or ss_sold_date_sk between 2437635 and 2437665
+or ss_sold_date_sk between 2438000 and 2438030
+or ss_sold_date_sk between 2438365 and 2438395
+or ss_sold_date_sk between 2438731 and 2438761
+or ss_sold_date_sk between 2439096 and 2439126
+or ss_sold_date_sk between 2439461 and 2439491
+or ss_sold_date_sk between 2439826 and 2439856
+or ss_sold_date_sk between 2440192 and 2440222
+or ss_sold_date_sk between 2440557 and 2440587
+or ss_sold_date_sk between 2440922 and 2440952
+or ss_sold_date_sk between 2441287 and 2441317
+or ss_sold_date_sk between 2441653 and 2441683
+or ss_sold_date_sk between 2442018 and 2442048
+or ss_sold_date_sk between 2442383 and 2442413
+or ss_sold_date_sk between 2442748 and 2442778
+or ss_sold_date_sk between 2443114 and 2443144
+or ss_sold_date_sk between 2443479 and 2443509
+or ss_sold_date_sk between 2443844 and 2443874
+or ss_sold_date_sk between 2444209 and 2444239
+or ss_sold_date_sk between 2444575 and 2444605
+or ss_sold_date_sk between 2444940 and 2444970
+or ss_sold_date_sk between 2445305 and 2445335
+or ss_sold_date_sk between 2445670 and 2445700
+or ss_sold_date_sk between 2446036 and 2446066
+or ss_sold_date_sk between 2446401 and 2446431
+or ss_sold_date_sk between 2446766 and 2446796
+or ss_sold_date_sk between 2447131 and 2447161
+or ss_sold_date_sk between 2447497 and 2447527
+or ss_sold_date_sk between 2447862 and 2447892
+or ss_sold_date_sk between 2448227 and 2448257
+or ss_sold_date_sk between 2448592 and 2448622
+or ss_sold_date_sk between 2448958 and 2448988
+or ss_sold_date_sk between 2449323 and 2449353
+or ss_sold_date_sk between 2449688 and 2449718
+or ss_sold_date_sk between 2450053 and 2450083
+or ss_sold_date_sk between 2450419 and 2450449
+or ss_sold_date_sk between 2450784 and 2450814
+or ss_sold_date_sk between 2451149 and 2451179
+or ss_sold_date_sk between 2451514 and 2451544
+or ss_sold_date_sk between 2451880 and 2451910
+or ss_sold_date_sk between 2452245 and 2452275
+or ss_sold_date_sk between 2452610 and 2452640
+or ss_sold_date_sk between 2452975 and 2453005
+or ss_sold_date_sk between 2453341 and 2453371
+or ss_sold_date_sk between 2453706 and 2453736
+or ss_sold_date_sk between 2454071 and 2454101
+or ss_sold_date_sk between 2454436 and 2454466
+or ss_sold_date_sk between 2454802 and 2454832
+or ss_sold_date_sk between 2455167 and 2455197
+or ss_sold_date_sk between 2455532 and 2455562
+or ss_sold_date_sk between 2455897 and 2455927
+or ss_sold_date_sk between 2456263 and 2456293
+or ss_sold_date_sk between 2456628 and 2456658
+or ss_sold_date_sk between 2456993 and 2457023
+or ss_sold_date_sk between 2457358 and 2457388
+or ss_sold_date_sk between 2457724 and 2457754
+or ss_sold_date_sk between 2458089 and 2458119
+or ss_sold_date_sk between 2458454 and 2458484
+or ss_sold_date_sk between 2458819 and 2458849
+or ss_sold_date_sk between 2459185 and 2459215
+or ss_sold_date_sk between 2459550 and 2459580
+or ss_sold_date_sk between 2459915 and 2459945
+or ss_sold_date_sk between 2460280 and 2460310
+or ss_sold_date_sk between 2460646 and 2460676
+or ss_sold_date_sk between 2461011 and 2461041
+or ss_sold_date_sk between 2461376 and 2461406
+or ss_sold_date_sk between 2461741 and 2461771
+or ss_sold_date_sk between 2462107 and 2462137
+or ss_sold_date_sk between 2462472 and 2462502
+or ss_sold_date_sk between 2462837 and 2462867
+or ss_sold_date_sk between 2463202 and 2463232
+or ss_sold_date_sk between 2463568 and 2463598
+or ss_sold_date_sk between 2463933 and 2463963
+or ss_sold_date_sk between 2464298 and 2464328
+or ss_sold_date_sk between 2464663 and 2464693
+or ss_sold_date_sk between 2465029 and 2465059
+or ss_sold_date_sk between 2465394 and 2465424
+or ss_sold_date_sk between 2465759 and 2465789
+or ss_sold_date_sk between 2466124 and 2466154
+or ss_sold_date_sk between 2466490 and 2466520
+or ss_sold_date_sk between 2466855 and 2466885
+or ss_sold_date_sk between 2467220 and 2467250
+or ss_sold_date_sk between 2467585 and 2467615
+or ss_sold_date_sk between 2467951 and 2467981
+or ss_sold_date_sk between 2468316 and 2468346
+or ss_sold_date_sk between 2468681 and 2468711
+or ss_sold_date_sk between 2469046 and 2469076
+or ss_sold_date_sk between 2469412 and 2469442
+or ss_sold_date_sk between 2469777 and 2469807
+or ss_sold_date_sk between 2470142 and 2470172
+or ss_sold_date_sk between 2470507 and 2470537
+or ss_sold_date_sk between 2470873 and 2470903
+or ss_sold_date_sk between 2471238 and 2471268
+or ss_sold_date_sk between 2471603 and 2471633
+or ss_sold_date_sk between 2471968 and 2471998
+or ss_sold_date_sk between 2472334 and 2472364
+or ss_sold_date_sk between 2472699 and 2472729
+or ss_sold_date_sk between 2473064 and 2473094
+or ss_sold_date_sk between 2473429 and 2473459
+or ss_sold_date_sk between 2473795 and 2473825
+or ss_sold_date_sk between 2474160 and 2474190
+or ss_sold_date_sk between 2474525 and 2474555
+or ss_sold_date_sk between 2474890 and 2474920
+or ss_sold_date_sk between 2475256 and 2475286
+or ss_sold_date_sk between 2475621 and 2475651
+or ss_sold_date_sk between 2475986 and 2476016
+or ss_sold_date_sk between 2476351 and 2476381
+or ss_sold_date_sk between 2476717 and 2476747
+or ss_sold_date_sk between 2477082 and 2477112
+or ss_sold_date_sk between 2477447 and 2477477
+or ss_sold_date_sk between 2477812 and 2477842
+or ss_sold_date_sk between 2478178 and 2478208
+or ss_sold_date_sk between 2478543 and 2478573
+or ss_sold_date_sk between 2478908 and 2478938
+or ss_sold_date_sk between 2479273 and 2479303
+or ss_sold_date_sk between 2479639 and 2479669
+or ss_sold_date_sk between 2480004 and 2480034
+or ss_sold_date_sk between 2480369 and 2480399
+or ss_sold_date_sk between 2480734 and 2480764
+or ss_sold_date_sk between 2481100 and 2481130
+or ss_sold_date_sk between 2481465 and 2481495
+or ss_sold_date_sk between 2481830 and 2481860
+or ss_sold_date_sk between 2482195 and 2482225
+or ss_sold_date_sk between 2482561 and 2482591
+or ss_sold_date_sk between 2482926 and 2482956
+or ss_sold_date_sk between 2483291 and 2483321
+or ss_sold_date_sk between 2483656 and 2483686
+or ss_sold_date_sk between 2484022 and 2484052
+or ss_sold_date_sk between 2484387 and 2484417
+or ss_sold_date_sk between 2484752 and 2484782
+or ss_sold_date_sk between 2485117 and 2485147
+or ss_sold_date_sk between 2485483 and 2485513
+or ss_sold_date_sk between 2485848 and 2485878
+or ss_sold_date_sk between 2486213 and 2486243
+or ss_sold_date_sk between 2486578 and 2486608
+or ss_sold_date_sk between 2486944 and 2486974
+or ss_sold_date_sk between 2487309 and 2487339
+or ss_sold_date_sk between 2487674 and 2487704
+or ss_sold_date_sk between 2488039 and 2488069
+)
+group by
+ dt.d_year,
+ item.i_brand,
+ item.i_brand_id
+order by
+ dt.d_year,
+ sum_agg desc,
+ brand_id
+limit 100
+-- end query 3 in stream 0 using template query3.tpl
diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q34.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q34.sql
new file mode 100755
index 0000000000000..d11696e5e0c34
--- /dev/null
+++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q34.sql
@@ -0,0 +1,45 @@
+-- start query 34 in stream 0 using template query34.tpl
+select
+ c_last_name,
+ c_first_name,
+ c_salutation,
+ c_preferred_cust_flag,
+ ss_ticket_number,
+ cnt
+from
+ (select
+ ss_ticket_number,
+ ss_customer_sk,
+ count(*) cnt
+ from
+ store_sales,
+ date_dim,
+ store,
+ household_demographics
+ where
+ store_sales.ss_sold_date_sk = date_dim.d_date_sk
+ and store_sales.ss_store_sk = store.s_store_sk
+ and store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk
+ and (date_dim.d_dom between 1 and 3
+ or date_dim.d_dom between 25 and 28)
+ and (household_demographics.hd_buy_potential = '>10000'
+ or household_demographics.hd_buy_potential = 'Unknown')
+ and household_demographics.hd_vehicle_count > 0
+ and (case when household_demographics.hd_vehicle_count > 0 then household_demographics.hd_dep_count / household_demographics.hd_vehicle_count else null end) > 1.2
+ and date_dim.d_year in (1998, 1998 + 1, 1998 + 2)
+ and store.s_county in ('Saginaw County', 'Sumner County', 'Appanoose County', 'Daviess County', 'Fairfield County', 'Raleigh County', 'Ziebach County', 'Williamson County')
+ and ss_sold_date_sk between 2450816 and 2451910 -- partition key filter
+ group by
+ ss_ticket_number,
+ ss_customer_sk
+ ) dn,
+ customer
+where
+ ss_customer_sk = c_customer_sk
+ and cnt between 15 and 20
+order by
+ c_last_name,
+ c_first_name,
+ c_salutation,
+ c_preferred_cust_flag desc
+-- end query 34 in stream 0 using template query34.tpl
diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q42.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q42.sql
new file mode 100755
index 0000000000000..b6332a8afbebe
--- /dev/null
+++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q42.sql
@@ -0,0 +1,28 @@
+-- start query 42 in stream 0 using template query42.tpl
+select
+ dt.d_year,
+ item.i_category_id,
+ item.i_category,
+ sum(ss_ext_sales_price)
+from
+ date_dim dt,
+ store_sales,
+ item
+where
+ dt.d_date_sk = store_sales.ss_sold_date_sk
+ and store_sales.ss_item_sk = item.i_item_sk
+ and item.i_manager_id = 1
+ and dt.d_moy = 12
+ and dt.d_year = 1998
+ and ss_sold_date_sk between 2451149 and 2451179 -- partition key filter
+group by
+ dt.d_year,
+ item.i_category_id,
+ item.i_category
+order by
+ sum(ss_ext_sales_price) desc,
+ dt.d_year,
+ item.i_category_id,
+ item.i_category
+limit 100
+-- end query 42 in stream 0 using template query42.tpl
diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q43.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q43.sql
new file mode 100755
index 0000000000000..cc2040b2fdb7c
--- /dev/null
+++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q43.sql
@@ -0,0 +1,36 @@
+-- start query 43 in stream 0 using template query43.tpl
+select
+ s_store_name,
+ s_store_id,
+ sum(case when (d_day_name = 'Sunday') then ss_sales_price else null end) sun_sales,
+ sum(case when (d_day_name = 'Monday') then ss_sales_price else null end) mon_sales,
+ sum(case when (d_day_name = 'Tuesday') then ss_sales_price else null end) tue_sales,
+ sum(case when (d_day_name = 'Wednesday') then ss_sales_price else null end) wed_sales,
+ sum(case when (d_day_name = 'Thursday') then ss_sales_price else null end) thu_sales,
+ sum(case when (d_day_name = 'Friday') then ss_sales_price else null end) fri_sales,
+ sum(case when (d_day_name = 'Saturday') then ss_sales_price else null end) sat_sales
+from
+ date_dim,
+ store_sales,
+ store
+where
+ d_date_sk = ss_sold_date_sk
+ and s_store_sk = ss_store_sk
+ and s_gmt_offset = -5
+ and d_year = 1998
+ and ss_sold_date_sk between 2450816 and 2451179 -- partition key filter
+group by
+ s_store_name,
+ s_store_id
+order by
+ s_store_name,
+ s_store_id,
+ sun_sales,
+ mon_sales,
+ tue_sales,
+ wed_sales,
+ thu_sales,
+ fri_sales,
+ sat_sales
+limit 100
+-- end query 43 in stream 0 using template query43.tpl
diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q46.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q46.sql
new file mode 100755
index 0000000000000..52b7ba4f4b86b
--- /dev/null
+++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q46.sql
@@ -0,0 +1,80 @@
+-- start query 46 in stream 0 using template query46.tpl
+select
+ c_last_name,
+ c_first_name,
+ ca_city,
+ bought_city,
+ ss_ticket_number,
+ amt,
+ profit
+from
+ (select
+ ss_ticket_number,
+ ss_customer_sk,
+ ca_city bought_city,
+ sum(ss_coupon_amt) amt,
+ sum(ss_net_profit) profit
+ from
+ store_sales,
+ date_dim,
+ store,
+ household_demographics,
+ customer_address
+ where
+ store_sales.ss_sold_date_sk = date_dim.d_date_sk
+ and store_sales.ss_store_sk = store.s_store_sk
+ and store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk
+ and store_sales.ss_addr_sk = customer_address.ca_address_sk
+ and (household_demographics.hd_dep_count = 5
+ or household_demographics.hd_vehicle_count = 3)
+ and date_dim.d_dow in (6, 0)
+ and date_dim.d_year in (1999, 1999 + 1, 1999 + 2)
+ and store.s_city in ('Midway', 'Concord', 'Spring Hill', 'Brownsville', 'Greenville')
+ -- partition key filter
+ and ss_sold_date_sk in (2451181, 2451182, 2451188, 2451189, 2451195, 2451196, 2451202, 2451203, 2451209, 2451210, 2451216, 2451217,
+ 2451223, 2451224, 2451230, 2451231, 2451237, 2451238, 2451244, 2451245, 2451251, 2451252, 2451258, 2451259,
+ 2451265, 2451266, 2451272, 2451273, 2451279, 2451280, 2451286, 2451287, 2451293, 2451294, 2451300, 2451301,
+ 2451307, 2451308, 2451314, 2451315, 2451321, 2451322, 2451328, 2451329, 2451335, 2451336, 2451342, 2451343,
+ 2451349, 2451350, 2451356, 2451357, 2451363, 2451364, 2451370, 2451371, 2451377, 2451378, 2451384, 2451385,
+ 2451391, 2451392, 2451398, 2451399, 2451405, 2451406, 2451412, 2451413, 2451419, 2451420, 2451426, 2451427,
+ 2451433, 2451434, 2451440, 2451441, 2451447, 2451448, 2451454, 2451455, 2451461, 2451462, 2451468, 2451469,
+ 2451475, 2451476, 2451482, 2451483, 2451489, 2451490, 2451496, 2451497, 2451503, 2451504, 2451510, 2451511,
+ 2451517, 2451518, 2451524, 2451525, 2451531, 2451532, 2451538, 2451539, 2451545, 2451546, 2451552, 2451553,
+ 2451559, 2451560, 2451566, 2451567, 2451573, 2451574, 2451580, 2451581, 2451587, 2451588, 2451594, 2451595,
+ 2451601, 2451602, 2451608, 2451609, 2451615, 2451616, 2451622, 2451623, 2451629, 2451630, 2451636, 2451637,
+ 2451643, 2451644, 2451650, 2451651, 2451657, 2451658, 2451664, 2451665, 2451671, 2451672, 2451678, 2451679,
+ 2451685, 2451686, 2451692, 2451693, 2451699, 2451700, 2451706, 2451707, 2451713, 2451714, 2451720, 2451721,
+ 2451727, 2451728, 2451734, 2451735, 2451741, 2451742, 2451748, 2451749, 2451755, 2451756, 2451762, 2451763,
+ 2451769, 2451770, 2451776, 2451777, 2451783, 2451784, 2451790, 2451791, 2451797, 2451798, 2451804, 2451805,
+ 2451811, 2451812, 2451818, 2451819, 2451825, 2451826, 2451832, 2451833, 2451839, 2451840, 2451846, 2451847,
+ 2451853, 2451854, 2451860, 2451861, 2451867, 2451868, 2451874, 2451875, 2451881, 2451882, 2451888, 2451889,
+ 2451895, 2451896, 2451902, 2451903, 2451909, 2451910, 2451916, 2451917, 2451923, 2451924, 2451930, 2451931,
+ 2451937, 2451938, 2451944, 2451945, 2451951, 2451952, 2451958, 2451959, 2451965, 2451966, 2451972, 2451973,
+ 2451979, 2451980, 2451986, 2451987, 2451993, 2451994, 2452000, 2452001, 2452007, 2452008, 2452014, 2452015,
+ 2452021, 2452022, 2452028, 2452029, 2452035, 2452036, 2452042, 2452043, 2452049, 2452050, 2452056, 2452057,
+ 2452063, 2452064, 2452070, 2452071, 2452077, 2452078, 2452084, 2452085, 2452091, 2452092, 2452098, 2452099,
+ 2452105, 2452106, 2452112, 2452113, 2452119, 2452120, 2452126, 2452127, 2452133, 2452134, 2452140, 2452141,
+ 2452147, 2452148, 2452154, 2452155, 2452161, 2452162, 2452168, 2452169, 2452175, 2452176, 2452182, 2452183,
+ 2452189, 2452190, 2452196, 2452197, 2452203, 2452204, 2452210, 2452211, 2452217, 2452218, 2452224, 2452225,
+ 2452231, 2452232, 2452238, 2452239, 2452245, 2452246, 2452252, 2452253, 2452259, 2452260, 2452266, 2452267,
+ 2452273, 2452274)
+ group by
+ ss_ticket_number,
+ ss_customer_sk,
+ ss_addr_sk,
+ ca_city
+ ) dn,
+ customer,
+ customer_address current_addr
+where
+ ss_customer_sk = c_customer_sk
+ and customer.c_current_addr_sk = current_addr.ca_address_sk
+ and current_addr.ca_city <> bought_city
+order by
+ c_last_name,
+ c_first_name,
+ ca_city,
+ bought_city,
+ ss_ticket_number
+limit 100
+-- end query 46 in stream 0 using template query46.tpl
diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q52.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q52.sql
new file mode 100755
index 0000000000000..a510eefb13e17
--- /dev/null
+++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q52.sql
@@ -0,0 +1,27 @@
+-- start query 52 in stream 0 using template query52.tpl
+select
+ dt.d_year,
+ item.i_brand_id brand_id,
+ item.i_brand brand,
+ sum(ss_ext_sales_price) ext_price
+from
+ date_dim dt,
+ store_sales,
+ item
+where
+ dt.d_date_sk = store_sales.ss_sold_date_sk
+ and store_sales.ss_item_sk = item.i_item_sk
+ and item.i_manager_id = 1
+ and dt.d_moy = 12
+ and dt.d_year = 1998
+ and ss_sold_date_sk between 2451149 and 2451179 -- added for partition pruning
+group by
+ dt.d_year,
+ item.i_brand,
+ item.i_brand_id
+order by
+ dt.d_year,
+ ext_price desc,
+ brand_id
+limit 100
+-- end query 52 in stream 0 using template query52.tpl
diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q53.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q53.sql
new file mode 100755
index 0000000000000..fb7bb75183858
--- /dev/null
+++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q53.sql
@@ -0,0 +1,37 @@
+-- start query 53 in stream 0 using template query53.tpl
+select
+ *
+from
+ (select
+ i_manufact_id,
+ sum(ss_sales_price) sum_sales,
+ avg(sum(ss_sales_price)) over (partition by i_manufact_id) avg_quarterly_sales
+ from
+ item,
+ store_sales,
+ date_dim,
+ store
+ where
+ ss_item_sk = i_item_sk
+ and ss_sold_date_sk = d_date_sk
+ and ss_store_sk = s_store_sk
+ and d_month_seq in (1212, 1212 + 1, 1212 + 2, 1212 + 3, 1212 + 4, 1212 + 5, 1212 + 6, 1212 + 7, 1212 + 8, 1212 + 9, 1212 + 10, 1212 + 11)
+ and ((i_category in ('Books', 'Children', 'Electronics')
+ and i_class in ('personal', 'portable', 'reference', 'self-help')
+ and i_brand in ('scholaramalgamalg #14', 'scholaramalgamalg #7', 'exportiunivamalg #9', 'scholaramalgamalg #9'))
+ or (i_category in ('Women', 'Music', 'Men')
+ and i_class in ('accessories', 'classical', 'fragrances', 'pants')
+ and i_brand in ('amalgimporto #1', 'edu packscholar #1', 'exportiimporto #1', 'importoamalg #1')))
+ and ss_sold_date_sk between 2451911 and 2452275 -- partition key filter
+ group by
+ i_manufact_id,
+ d_qoy
+ ) tmp1
+where
+ case when avg_quarterly_sales > 0 then abs (sum_sales - avg_quarterly_sales) / avg_quarterly_sales else null end > 0.1
+order by
+ avg_quarterly_sales,
+ sum_sales,
+ i_manufact_id
+limit 100
+-- end query 53 in stream 0 using template query53.tpl
diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q55.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q55.sql
new file mode 100755
index 0000000000000..47b1f0292d901
--- /dev/null
+++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q55.sql
@@ -0,0 +1,24 @@
+-- start query 55 in stream 0 using template query55.tpl
+select
+ i_brand_id brand_id,
+ i_brand brand,
+ sum(ss_ext_sales_price) ext_price
+from
+ date_dim,
+ store_sales,
+ item
+where
+ d_date_sk = ss_sold_date_sk
+ and ss_item_sk = i_item_sk
+ and i_manager_id = 48
+ and d_moy = 11
+ and d_year = 2001
+ and ss_sold_date_sk between 2452215 and 2452244
+group by
+ i_brand,
+ i_brand_id
+order by
+ ext_price desc,
+ i_brand_id
+limit 100
+-- end query 55 in stream 0 using template query55.tpl
diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q59.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q59.sql
new file mode 100755
index 0000000000000..3d5c4e9d64419
--- /dev/null
+++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q59.sql
@@ -0,0 +1,83 @@
+-- start query 59 in stream 0 using template query59.tpl
+with
+ wss as
+ (select
+ d_week_seq,
+ ss_store_sk,
+ sum(case when (d_day_name = 'Sunday') then ss_sales_price else null end) sun_sales,
+ sum(case when (d_day_name = 'Monday') then ss_sales_price else null end) mon_sales,
+ sum(case when (d_day_name = 'Tuesday') then ss_sales_price else null end) tue_sales,
+ sum(case when (d_day_name = 'Wednesday') then ss_sales_price else null end) wed_sales,
+ sum(case when (d_day_name = 'Thursday') then ss_sales_price else null end) thu_sales,
+ sum(case when (d_day_name = 'Friday') then ss_sales_price else null end) fri_sales,
+ sum(case when (d_day_name = 'Saturday') then ss_sales_price else null end) sat_sales
+ from
+ store_sales,
+ date_dim
+ where
+ d_date_sk = ss_sold_date_sk
+ group by
+ d_week_seq,
+ ss_store_sk
+ )
+select
+ s_store_name1,
+ s_store_id1,
+ d_week_seq1,
+ sun_sales1 / sun_sales2,
+ mon_sales1 / mon_sales2,
+ tue_sales1 / tue_sales1,
+ wed_sales1 / wed_sales2,
+ thu_sales1 / thu_sales2,
+ fri_sales1 / fri_sales2,
+ sat_sales1 / sat_sales2
+from
+ (select
+ s_store_name s_store_name1,
+ wss.d_week_seq d_week_seq1,
+ s_store_id s_store_id1,
+ sun_sales sun_sales1,
+ mon_sales mon_sales1,
+ tue_sales tue_sales1,
+ wed_sales wed_sales1,
+ thu_sales thu_sales1,
+ fri_sales fri_sales1,
+ sat_sales sat_sales1
+ from
+ wss,
+ store,
+ date_dim d
+ where
+ d.d_week_seq = wss.d_week_seq
+ and ss_store_sk = s_store_sk
+ and d_month_seq between 1185 and 1185 + 11
+ ) y,
+ (select
+ s_store_name s_store_name2,
+ wss.d_week_seq d_week_seq2,
+ s_store_id s_store_id2,
+ sun_sales sun_sales2,
+ mon_sales mon_sales2,
+ tue_sales tue_sales2,
+ wed_sales wed_sales2,
+ thu_sales thu_sales2,
+ fri_sales fri_sales2,
+ sat_sales sat_sales2
+ from
+ wss,
+ store,
+ date_dim d
+ where
+ d.d_week_seq = wss.d_week_seq
+ and ss_store_sk = s_store_sk
+ and d_month_seq between 1185 + 12 and 1185 + 23
+ ) x
+where
+ s_store_id1 = s_store_id2
+ and d_week_seq1 = d_week_seq2 - 52
+order by
+ s_store_name1,
+ s_store_id1,
+ d_week_seq1
+limit 100
+-- end query 59 in stream 0 using template query59.tpl
diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q63.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q63.sql
new file mode 100755
index 0000000000000..b71199ab17d0b
--- /dev/null
+++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q63.sql
@@ -0,0 +1,29 @@
+-- start query 63 in stream 0 using template query63.tpl
+select *
+from (select i_manager_id
+ ,sum(ss_sales_price) sum_sales
+ ,avg(sum(ss_sales_price)) over (partition by i_manager_id) avg_monthly_sales
+ from item
+ ,store_sales
+ ,date_dim
+ ,store
+ where ss_item_sk = i_item_sk
+ and ss_sold_date_sk = d_date_sk
+ and ss_sold_date_sk between 2452123 and 2452487
+ and ss_store_sk = s_store_sk
+ and d_month_seq in (1219,1219+1,1219+2,1219+3,1219+4,1219+5,1219+6,1219+7,1219+8,1219+9,1219+10,1219+11)
+ and (( i_category in ('Books','Children','Electronics')
+ and i_class in ('personal','portable','reference','self-help')
+ and i_brand in ('scholaramalgamalg #14','scholaramalgamalg #7',
+ 'exportiunivamalg #9','scholaramalgamalg #9'))
+ or( i_category in ('Women','Music','Men')
+ and i_class in ('accessories','classical','fragrances','pants')
+ and i_brand in ('amalgimporto #1','edu packscholar #1','exportiimporto #1',
+ 'importoamalg #1')))
+group by i_manager_id, d_moy) tmp1
+where case when avg_monthly_sales > 0 then abs (sum_sales - avg_monthly_sales) / avg_monthly_sales else null end > 0.1
+order by i_manager_id
+ ,avg_monthly_sales
+ ,sum_sales
+limit 100
+-- end query 63 in stream 0 using template query63.tpl
diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q65.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q65.sql
new file mode 100755
index 0000000000000..7344feeff6a9f
--- /dev/null
+++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q65.sql
@@ -0,0 +1,58 @@
+-- start query 65 in stream 0 using template query65.tpl
+select
+ s_store_name,
+ i_item_desc,
+ sc.revenue,
+ i_current_price,
+ i_wholesale_cost,
+ i_brand
+from
+ store,
+ item,
+ (select
+ ss_store_sk,
+ avg(revenue) as ave
+ from
+ (select
+ ss_store_sk,
+ ss_item_sk,
+ sum(ss_sales_price) as revenue
+ from
+ store_sales,
+ date_dim
+ where
+ ss_sold_date_sk = d_date_sk
+ and d_month_seq between 1212 and 1212 + 11
+ and ss_sold_date_sk between 2451911 and 2452275 -- partition key filter
+ group by
+ ss_store_sk,
+ ss_item_sk
+ ) sa
+ group by
+ ss_store_sk
+ ) sb,
+ (select
+ ss_store_sk,
+ ss_item_sk,
+ sum(ss_sales_price) as revenue
+ from
+ store_sales,
+ date_dim
+ where
+ ss_sold_date_sk = d_date_sk
+ and d_month_seq between 1212 and 1212 + 11
+ and ss_sold_date_sk between 2451911 and 2452275 -- partition key filter
+ group by
+ ss_store_sk,
+ ss_item_sk
+ ) sc
+where
+ sb.ss_store_sk = sc.ss_store_sk
+ and sc.revenue <= 0.1 * sb.ave
+ and s_store_sk = sc.ss_store_sk
+ and i_item_sk = sc.ss_item_sk
+order by
+ s_store_name,
+ i_item_desc
+limit 100
+-- end query 65 in stream 0 using template query65.tpl
diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q68.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q68.sql
new file mode 100755
index 0000000000000..94df4b3f57a90
--- /dev/null
+++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q68.sql
@@ -0,0 +1,62 @@
+-- start query 68 in stream 0 using template query68.tpl
+-- changed to match exact same partitions in original query
+select
+ c_last_name,
+ c_first_name,
+ ca_city,
+ bought_city,
+ ss_ticket_number,
+ extended_price,
+ extended_tax,
+ list_price
+from
+ (select
+ ss_ticket_number,
+ ss_customer_sk,
+ ca_city bought_city,
+ sum(ss_ext_sales_price) extended_price,
+ sum(ss_ext_list_price) list_price,
+ sum(ss_ext_tax) extended_tax
+ from
+ store_sales,
+ date_dim,
+ store,
+ household_demographics,
+ customer_address
+ where
+ store_sales.ss_sold_date_sk = date_dim.d_date_sk
+ and store_sales.ss_store_sk = store.s_store_sk
+ and store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk
+ and store_sales.ss_addr_sk = customer_address.ca_address_sk
+ and date_dim.d_dom between 1 and 2
+ and (household_demographics.hd_dep_count = 5
+ or household_demographics.hd_vehicle_count = 3)
+ and date_dim.d_year in (1999, 1999 + 1, 1999 + 2)
+ and store.s_city in ('Midway', 'Fairview')
+ -- partition key filter
+ and ss_sold_date_sk in (2451180, 2451181, 2451211, 2451212, 2451239, 2451240, 2451270, 2451271, 2451300, 2451301, 2451331,
+ 2451332, 2451361, 2451362, 2451392, 2451393, 2451423, 2451424, 2451453, 2451454, 2451484, 2451485,
+ 2451514, 2451515, 2451545, 2451546, 2451576, 2451577, 2451605, 2451606, 2451636, 2451637, 2451666,
+ 2451667, 2451697, 2451698, 2451727, 2451728, 2451758, 2451759, 2451789, 2451790, 2451819, 2451820,
+ 2451850, 2451851, 2451880, 2451881, 2451911, 2451912, 2451942, 2451943, 2451970, 2451971, 2452001,
+ 2452002, 2452031, 2452032, 2452062, 2452063, 2452092, 2452093, 2452123, 2452124, 2452154, 2452155,
+ 2452184, 2452185, 2452215, 2452216, 2452245, 2452246)
+ --and ss_sold_date_sk between 2451180 and 2451269 -- partition key filter (3 months)
+ --and d_date between '1999-01-01' and '1999-03-31'
+ group by
+ ss_ticket_number,
+ ss_customer_sk,
+ ss_addr_sk,
+ ca_city
+ ) dn,
+ customer,
+ customer_address current_addr
+where
+ ss_customer_sk = c_customer_sk
+ and customer.c_current_addr_sk = current_addr.ca_address_sk
+ and current_addr.ca_city <> bought_city
+order by
+ c_last_name,
+ ss_ticket_number
+limit 100
+-- end query 68 in stream 0 using template query68.tpl
diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q7.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q7.sql
new file mode 100755
index 0000000000000..c61a2d0d2a8fa
--- /dev/null
+++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q7.sql
@@ -0,0 +1,31 @@
+-- start query 7 in stream 0 using template query7.tpl
+select
+ i_item_id,
+ avg(ss_quantity) agg1,
+ avg(ss_list_price) agg2,
+ avg(ss_coupon_amt) agg3,
+ avg(ss_sales_price) agg4
+from
+ store_sales,
+ customer_demographics,
+ date_dim,
+ item,
+ promotion
+where
+ ss_sold_date_sk = d_date_sk
+ and ss_item_sk = i_item_sk
+ and ss_cdemo_sk = cd_demo_sk
+ and ss_promo_sk = p_promo_sk
+ and cd_gender = 'F'
+ and cd_marital_status = 'W'
+ and cd_education_status = 'Primary'
+ and (p_channel_email = 'N'
+ or p_channel_event = 'N')
+ and d_year = 1998
+ and ss_sold_date_sk between 2450815 and 2451179 -- partition key filter
+group by
+ i_item_id
+order by
+ i_item_id
+limit 100
+-- end query 7 in stream 0 using template query7.tpl
diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q73.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q73.sql
new file mode 100755
index 0000000000000..8703910b305a8
--- /dev/null
+++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q73.sql
@@ -0,0 +1,49 @@
+-- start query 73 in stream 0 using template query73.tpl
+select
+ c_last_name,
+ c_first_name,
+ c_salutation,
+ c_preferred_cust_flag,
+ ss_ticket_number,
+ cnt
+from
+ (select
+ ss_ticket_number,
+ ss_customer_sk,
+ count(*) cnt
+ from
+ store_sales,
+ date_dim,
+ store,
+ household_demographics
+ where
+ store_sales.ss_sold_date_sk = date_dim.d_date_sk
+ and store_sales.ss_store_sk = store.s_store_sk
+ and store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk
+ and date_dim.d_dom between 1 and 2
+ and (household_demographics.hd_buy_potential = '>10000'
+ or household_demographics.hd_buy_potential = 'Unknown')
+ and household_demographics.hd_vehicle_count > 0
+ and case when household_demographics.hd_vehicle_count > 0 then household_demographics.hd_dep_count / household_demographics.hd_vehicle_count else null end > 1
+ and date_dim.d_year in (1998, 1998 + 1, 1998 + 2)
+ and store.s_county in ('Fairfield County','Ziebach County','Bronx County','Barrow County')
+ -- partition key filter
+ and ss_sold_date_sk in (2450815, 2450816, 2450846, 2450847, 2450874, 2450875, 2450905, 2450906, 2450935, 2450936, 2450966, 2450967,
+ 2450996, 2450997, 2451027, 2451028, 2451058, 2451059, 2451088, 2451089, 2451119, 2451120, 2451149,
+ 2451150, 2451180, 2451181, 2451211, 2451212, 2451239, 2451240, 2451270, 2451271, 2451300, 2451301,
+ 2451331, 2451332, 2451361, 2451362, 2451392, 2451393, 2451423, 2451424, 2451453, 2451454, 2451484,
+ 2451485, 2451514, 2451515, 2451545, 2451546, 2451576, 2451577, 2451605, 2451606, 2451636, 2451637,
+ 2451666, 2451667, 2451697, 2451698, 2451727, 2451728, 2451758, 2451759, 2451789, 2451790, 2451819,
+ 2451820, 2451850, 2451851, 2451880, 2451881)
+ --and ss_sold_date_sk between 2451180 and 2451269 -- partition key filter (3 months)
+ group by
+ ss_ticket_number,
+ ss_customer_sk
+ ) dj,
+ customer
+where
+ ss_customer_sk = c_customer_sk
+ and cnt between 1 and 5
+order by
+ cnt desc
+-- end query 73 in stream 0 using template query73.tpl
diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q79.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q79.sql
new file mode 100755
index 0000000000000..4254310ecd10b
--- /dev/null
+++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q79.sql
@@ -0,0 +1,59 @@
+-- start query 79 in stream 0 using template query79.tpl
+select
+ c_last_name,
+ c_first_name,
+ substr(s_city, 1, 30),
+ ss_ticket_number,
+ amt,
+ profit
+from
+ (select
+ ss_ticket_number,
+ ss_customer_sk,
+ store.s_city,
+ sum(ss_coupon_amt) amt,
+ sum(ss_net_profit) profit
+ from
+ store_sales,
+ date_dim,
+ store,
+ household_demographics
+ where
+ store_sales.ss_sold_date_sk = date_dim.d_date_sk
+ and store_sales.ss_store_sk = store.s_store_sk
+ and store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk
+ and (household_demographics.hd_dep_count = 8
+ or household_demographics.hd_vehicle_count > 0)
+ and date_dim.d_dow = 1
+ and date_dim.d_year in (1998, 1998 + 1, 1998 + 2)
+ and store.s_number_employees between 200 and 295
+ and ss_sold_date_sk between 2450819 and 2451904
+ -- partition key filter
+ --and ss_sold_date_sk in (2450819, 2450826, 2450833, 2450840, 2450847, 2450854, 2450861, 2450868, 2450875, 2450882, 2450889,
+ -- 2450896, 2450903, 2450910, 2450917, 2450924, 2450931, 2450938, 2450945, 2450952, 2450959, 2450966, 2450973, 2450980, 2450987,
+ -- 2450994, 2451001, 2451008, 2451015, 2451022, 2451029, 2451036, 2451043, 2451050, 2451057, 2451064, 2451071, 2451078, 2451085,
+ -- 2451092, 2451099, 2451106, 2451113, 2451120, 2451127, 2451134, 2451141, 2451148, 2451155, 2451162, 2451169, 2451176, 2451183,
+ -- 2451190, 2451197, 2451204, 2451211, 2451218, 2451225, 2451232, 2451239, 2451246, 2451253, 2451260, 2451267, 2451274, 2451281,
+ -- 2451288, 2451295, 2451302, 2451309, 2451316, 2451323, 2451330, 2451337, 2451344, 2451351, 2451358, 2451365, 2451372, 2451379,
+ -- 2451386, 2451393, 2451400, 2451407, 2451414, 2451421, 2451428, 2451435, 2451442, 2451449, 2451456, 2451463, 2451470, 2451477,
+ -- 2451484, 2451491, 2451498, 2451505, 2451512, 2451519, 2451526, 2451533, 2451540, 2451547, 2451554, 2451561, 2451568, 2451575,
+ -- 2451582, 2451589, 2451596, 2451603, 2451610, 2451617, 2451624, 2451631, 2451638, 2451645, 2451652, 2451659, 2451666, 2451673,
+ -- 2451680, 2451687, 2451694, 2451701, 2451708, 2451715, 2451722, 2451729, 2451736, 2451743, 2451750, 2451757, 2451764, 2451771,
+ -- 2451778, 2451785, 2451792, 2451799, 2451806, 2451813, 2451820, 2451827, 2451834, 2451841, 2451848, 2451855, 2451862, 2451869,
+ -- 2451876, 2451883, 2451890, 2451897, 2451904)
+ group by
+ ss_ticket_number,
+ ss_customer_sk,
+ ss_addr_sk,
+ store.s_city
+ ) ms,
+ customer
+where
+ ss_customer_sk = c_customer_sk
+order by
+ c_last_name,
+ c_first_name,
+ substr(s_city, 1, 30),
+ profit
+ limit 100
+-- end query 79 in stream 0 using template query79.tpl
diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q89.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q89.sql
new file mode 100755
index 0000000000000..b1d814af5e57a
--- /dev/null
+++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q89.sql
@@ -0,0 +1,43 @@
+-- start query 89 in stream 0 using template query89.tpl
+select
+ *
+from
+ (select
+ i_category,
+ i_class,
+ i_brand,
+ s_store_name,
+ s_company_name,
+ d_moy,
+ sum(ss_sales_price) sum_sales,
+ avg(sum(ss_sales_price)) over (partition by i_category, i_brand, s_store_name, s_company_name) avg_monthly_sales
+ from
+ item,
+ store_sales,
+ date_dim,
+ store
+ where
+ ss_item_sk = i_item_sk
+ and ss_sold_date_sk = d_date_sk
+ and ss_store_sk = s_store_sk
+ and d_year in (2000)
+ and ((i_category in ('Home', 'Books', 'Electronics')
+ and i_class in ('wallpaper', 'parenting', 'musical'))
+ or (i_category in ('Shoes', 'Jewelry', 'Men')
+ and i_class in ('womens', 'birdal', 'pants')))
+ and ss_sold_date_sk between 2451545 and 2451910 -- partition key filter
+ group by
+ i_category,
+ i_class,
+ i_brand,
+ s_store_name,
+ s_company_name,
+ d_moy
+ ) tmp1
+where
+ case when (avg_monthly_sales <> 0) then (abs(sum_sales - avg_monthly_sales) / avg_monthly_sales) else null end > 0.1
+order by
+ sum_sales - avg_monthly_sales,
+ s_store_name
+limit 100
+-- end query 89 in stream 0 using template query89.tpl
diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q98.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q98.sql
new file mode 100755
index 0000000000000..f53f2f5f9c5b6
--- /dev/null
+++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q98.sql
@@ -0,0 +1,32 @@
+-- start query 98 in stream 0 using template query98.tpl
+select
+ i_item_desc,
+ i_category,
+ i_class,
+ i_current_price,
+ sum(ss_ext_sales_price) as itemrevenue,
+ sum(ss_ext_sales_price) * 100 / sum(sum(ss_ext_sales_price)) over (partition by i_class) as revenueratio
+from
+ store_sales,
+ item,
+ date_dim
+where
+ ss_item_sk = i_item_sk
+ and i_category in ('Jewelry', 'Sports', 'Books')
+ and ss_sold_date_sk = d_date_sk
+ and ss_sold_date_sk between 2451911 and 2451941 -- partition key filter (1 calendar month)
+ and d_date between '2001-01-01' and '2001-01-31'
+group by
+ i_item_id,
+ i_item_desc,
+ i_category,
+ i_class,
+ i_current_price
+order by
+ i_category,
+ i_class,
+ i_item_id,
+ i_item_desc,
+ revenueratio
+--limit 1000; -- added limit
+-- end query 98 in stream 0 using template query98.tpl
diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/ss_max.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/ss_max.sql
new file mode 100755
index 0000000000000..bf58b4bb3c5a5
--- /dev/null
+++ b/sql/core/src/test/resources/tpcds-modifiedQueries/ss_max.sql
@@ -0,0 +1,14 @@
+select
+ count(*) as total,
+ count(ss_sold_date_sk) as not_null_total,
+ count(distinct ss_sold_date_sk) as unique_days,
+ max(ss_sold_date_sk) as max_ss_sold_date_sk,
+ max(ss_sold_time_sk) as max_ss_sold_time_sk,
+ max(ss_item_sk) as max_ss_item_sk,
+ max(ss_customer_sk) as max_ss_customer_sk,
+ max(ss_cdemo_sk) as max_ss_cdemo_sk,
+ max(ss_hdemo_sk) as max_ss_hdemo_sk,
+ max(ss_addr_sk) as max_ss_addr_sk,
+ max(ss_store_sk) as max_ss_store_sk,
+ max(ss_promo_sk) as max_ss_promo_sk
+from store_sales
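The modified queries above pair each date_dim predicate with a redundant range on ss_sold_date_sk, flagged as `-- partition key filter`. A minimal sketch of that idiom, assuming store_sales is partitioned by ss_sold_date_sk (the bounds mirror the q65 filter above; the SparkSession `spark` is assumed):

```scala
// Hedged sketch, not part of the patch: the extra predicate on the partition column lets the
// scan prune partitions up front, even though the join to date_dim already restricts the same
// date range through d_month_seq.
spark.sql("""
  SELECT ss_store_sk, ss_item_sk, sum(ss_sales_price) AS revenue
  FROM store_sales JOIN date_dim ON ss_sold_date_sk = d_date_sk
  WHERE d_month_seq BETWEEN 1212 AND 1212 + 11
    AND ss_sold_date_sk BETWEEN 2451911 AND 2452275  -- partition key filter
  GROUP BY ss_store_sk, ss_item_sk
""").explain()  // with a partitioned store_sales, the plan should show partition filters on ss_sold_date_sk
```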
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala
index 3e85d95523125..7e61a68025158 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala
@@ -19,13 +19,12 @@ package org.apache.spark.sql
import org.scalatest.BeforeAndAfter
-class SingleLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeAndAfter {
+import org.apache.spark.SparkConf
- protected override def beforeAll(): Unit = {
- sparkConf.set("spark.sql.codegen.fallback", "false")
- sparkConf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false")
- super.beforeAll()
- }
+class SingleLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeAndAfter {
+ override protected def sparkConf: SparkConf = super.sparkConf
+ .set("spark.sql.codegen.fallback", "false")
+ .set("spark.sql.codegen.aggregate.map.twolevel.enable", "false")
// adding some checking after each test is run, assuring that the configs are not changed
// in test code
@@ -38,12 +37,9 @@ class SingleLevelAggregateHashMapSuite extends DataFrameAggregateSuite with Befo
}
class TwoLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeAndAfter {
-
- protected override def beforeAll(): Unit = {
- sparkConf.set("spark.sql.codegen.fallback", "false")
- sparkConf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true")
- super.beforeAll()
- }
+ override protected def sparkConf: SparkConf = super.sparkConf
+ .set("spark.sql.codegen.fallback", "false")
+ .set("spark.sql.codegen.aggregate.map.twolevel.enable", "true")
// adding some checking after each test is run, assuring that the configs are not changed
// in test code
@@ -55,15 +51,14 @@ class TwoLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeA
}
}
-class TwoLevelAggregateHashMapWithVectorizedMapSuite extends DataFrameAggregateSuite with
-BeforeAndAfter {
+class TwoLevelAggregateHashMapWithVectorizedMapSuite
+ extends DataFrameAggregateSuite
+ with BeforeAndAfter {
- protected override def beforeAll(): Unit = {
- sparkConf.set("spark.sql.codegen.fallback", "false")
- sparkConf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true")
- sparkConf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true")
- super.beforeAll()
- }
+ override protected def sparkConf: SparkConf = super.sparkConf
+ .set("spark.sql.codegen.fallback", "false")
+ .set("spark.sql.codegen.aggregate.map.twolevel.enable", "true")
+ .set("spark.sql.codegen.aggregate.map.vectorized.enable", "true")
// adding some checking after each test is run, assuring that the configs are not changed
// in test code
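As a reference point, a minimal sketch of the per-suite configuration pattern these suites use: overriding `sparkConf` applies the settings when the shared session is first built, rather than mutating a conf object in `beforeAll()`. The suite name below is illustrative; the config key is taken from the hunks above:

```scala
import org.apache.spark.SparkConf
import org.apache.spark.sql.test.SharedSQLContext

// Hypothetical suite; any test mixing in SharedSQLContext can override sparkConf this way.
class ExampleConfigSuite extends QueryTest with SharedSQLContext {
  override protected def sparkConf: SparkConf =
    super.sparkConf.set("spark.sql.codegen.fallback", "false")

  test("the overridden conf reaches the shared session") {
    assert(spark.conf.get("spark.sql.codegen.fallback") == "false")
  }
}
```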
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index e66fe97afad45..3ad526873f5d2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -647,7 +647,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
withTable("t") {
withTempPath { path =>
Seq(1 -> "a").toDF("i", "j").write.parquet(path.getCanonicalPath)
- sql(s"CREATE TABLE t USING parquet LOCATION '$path'")
+ sql(s"CREATE TABLE t USING parquet LOCATION '${path.toURI}'")
spark.catalog.cacheTable("t")
spark.table("t").select($"i").cache()
checkAnswer(spark.table("t").select($"i"), Row(1))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index b0f398dab7455..bc708ca88d7e1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -39,6 +39,9 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
StructType(Seq(StructField("a", BooleanType), StructField("b", BooleanType))))
}
+ private lazy val nullData = Seq(
+ (Some(1), Some(1)), (Some(1), Some(2)), (Some(1), None), (None, None)).toDF("a", "b")
+
test("column names with space") {
val df = Seq((1, "a")).toDF("name with space", "name.with.dot")
@@ -283,23 +286,6 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
}
test("<=>") {
- checkAnswer(
- testData2.filter($"a" === 1),
- testData2.collect().toSeq.filter(r => r.getInt(0) == 1))
-
- checkAnswer(
- testData2.filter($"a" === $"b"),
- testData2.collect().toSeq.filter(r => r.getInt(0) == r.getInt(1)))
- }
-
- test("=!=") {
- val nullData = spark.createDataFrame(sparkContext.parallelize(
- Row(1, 1) ::
- Row(1, 2) ::
- Row(1, null) ::
- Row(null, null) :: Nil),
- StructType(Seq(StructField("a", IntegerType), StructField("b", IntegerType))))
-
checkAnswer(
nullData.filter($"b" <=> 1),
Row(1, 1) :: Nil)
@@ -321,7 +307,18 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
checkAnswer(
nullData2.filter($"a" <=> null),
Row(null) :: Nil)
+ }
+ test("=!=") {
+ checkAnswer(
+ nullData.filter($"b" =!= 1),
+ Row(1, 2) :: Nil)
+
+ checkAnswer(nullData.filter($"b" =!= null), Nil)
+
+ checkAnswer(
+ nullData.filter($"a" =!= $"b"),
+ Row(1, 2) :: Nil)
}
test(">") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index e7079120bb7df..87aabf7220246 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql
+import org.apache.spark.sql.execution.aggregate.{ObjectHashAggregateExec, SortAggregateExec}
+import org.apache.spark.sql.execution.exchange.ShuffleExchange
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -186,6 +188,22 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
)
}
+ test("SPARK-21980: References in grouping functions should be indexed with semanticEquals") {
+ checkAnswer(
+ courseSales.cube("course", "year")
+ .agg(grouping("CouRse"), grouping("year")),
+ Row("Java", 2012, 0, 0) ::
+ Row("Java", 2013, 0, 0) ::
+ Row("Java", null, 0, 1) ::
+ Row("dotNET", 2012, 0, 0) ::
+ Row("dotNET", 2013, 0, 0) ::
+ Row("dotNET", null, 0, 1) ::
+ Row(null, 2012, 1, 0) ::
+ Row(null, 2013, 1, 0) ::
+ Row(null, null, 1, 1) :: Nil
+ )
+ }
+
test("rollup overlapping columns") {
checkAnswer(
testData2.rollup($"a" + $"b" as "foo", $"b" as "bar").agg(sum($"a" - $"b") as "foo"),
@@ -538,4 +556,56 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
Seq(Row(3, 0, 0.0, 1, 5.0), Row(2, 1, 4.0, 0, 0.0))
)
}
+
+ test("aggregate function in GROUP BY") {
+ val e = intercept[AnalysisException] {
+ testData.groupBy(sum($"key")).count()
+ }
+ assert(e.message.contains("aggregate functions are not allowed in GROUP BY"))
+ }
+
+ test("SPARK-21580 ints in aggregation expressions are taken as group-by ordinal.") {
+ checkAnswer(
+ testData2.groupBy(lit(3), lit(4)).agg(lit(6), lit(7), sum("b")),
+ Seq(Row(3, 4, 6, 7, 9)))
+ checkAnswer(
+ testData2.groupBy(lit(3), lit(4)).agg(lit(6), 'b, sum("b")),
+ Seq(Row(3, 4, 6, 1, 3), Row(3, 4, 6, 2, 6)))
+
+ checkAnswer(
+ spark.sql("SELECT 3, 4, SUM(b) FROM testData2 GROUP BY 1, 2"),
+ Seq(Row(3, 4, 9)))
+ checkAnswer(
+ spark.sql("SELECT 3 AS c, 4 AS d, SUM(b) FROM testData2 GROUP BY c, d"),
+ Seq(Row(3, 4, 9)))
+ }
+
+ test("SPARK-22223: ObjectHashAggregate should not introduce unnecessary shuffle") {
+ withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") {
+ val df = Seq(("1", "2", 1), ("1", "2", 2), ("2", "3", 3), ("2", "3", 4)).toDF("a", "b", "c")
+ .repartition(col("a"))
+
+ val objHashAggDF = df
+ .withColumn("d", expr("(a, b, c)"))
+ .groupBy("a", "b").agg(collect_list("d").as("e"))
+ .withColumn("f", expr("(b, e)"))
+ .groupBy("a").agg(collect_list("f").as("g"))
+ val aggPlan = objHashAggDF.queryExecution.executedPlan
+
+ val sortAggPlans = aggPlan.collect {
+ case sortAgg: SortAggregateExec => sortAgg
+ }
+ assert(sortAggPlans.isEmpty)
+
+ val objHashAggPlans = aggPlan.collect {
+ case objHashAgg: ObjectHashAggregateExec => objHashAgg
+ }
+ assert(objHashAggPlans.nonEmpty)
+
+ val exchangePlans = aggPlan.collect {
+ case shuffle: ShuffleExchange => shuffle
+ }
+ assert(exchangePlans.length == 1)
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala
new file mode 100644
index 0000000000000..60f6f23860ed9
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.test.SharedSQLContext
+
+class DataFrameHintSuite extends PlanTest with SharedSQLContext {
+ import testImplicits._
+ lazy val df = spark.range(10)
+
+ private def check(df: Dataset[_], expected: LogicalPlan) = {
+ comparePlans(
+ df.queryExecution.logical,
+ expected
+ )
+ }
+
+ test("various hint parameters") {
+ check(
+ df.hint("hint1"),
+ UnresolvedHint("hint1", Seq(),
+ df.logicalPlan
+ )
+ )
+
+ check(
+ df.hint("hint1", 1, "a"),
+ UnresolvedHint("hint1", Seq(1, "a"), df.logicalPlan)
+ )
+
+ check(
+ df.hint("hint1", 1, $"a"),
+ UnresolvedHint("hint1", Seq(1, $"a"),
+ df.logicalPlan
+ )
+ )
+
+ check(
+ df.hint("hint1", Seq(1, 2, 3), Seq($"a", $"b", $"c")),
+ UnresolvedHint("hint1", Seq(Seq(1, 2, 3), Seq($"a", $"b", $"c")),
+ df.logicalPlan
+ )
+ )
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index 541ffb58e727f..aef0d7f3e425b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -151,7 +151,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
Row(1, 1, 1, 1) :: Row(2, 1, 2, 2) :: Nil)
}
- test("broadcast join hint") {
+ test("broadcast join hint using broadcast function") {
val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value")
val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value")
@@ -174,6 +174,22 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
}
}
+ test("broadcast join hint using Dataset.hint") {
+ // make sure a giant join is not broadcastable
+ val plan1 =
+ spark.range(10e10.toLong)
+ .join(spark.range(10e10.toLong), "id")
+ .queryExecution.executedPlan
+ assert(plan1.collect { case p: BroadcastHashJoinExec => p }.size == 0)
+
+ // now with a hint it should be broadcasted
+ val plan2 =
+ spark.range(10e10.toLong)
+ .join(spark.range(10e10.toLong).hint("broadcast"), "id")
+ .queryExecution.executedPlan
+ assert(plan2.collect { case p: BroadcastHashJoinExec => p }.size == 1)
+ }
+
test("join - outer join conversion") {
val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str").as("a")
val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as("b")
@@ -248,4 +264,14 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
val ab = a.join(b, Seq("a"), "fullouter")
checkAnswer(ab.join(c, "a"), Row(3, null, 4, 1) :: Nil)
}
+
+ test("SPARK-17685: WholeStageCodegenExec throws IndexOutOfBoundsException") {
+ val df = Seq((1, 1, "1"), (2, 2, "3")).toDF("int", "int2", "str")
+ val df2 = Seq((1, 1, "1"), (2, 3, "5")).toDF("int", "int2", "str")
+ val limit = 1310721
+ val innerJoin = df.limit(limit).join(df2.limit(limit), Seq("int", "int2"), "inner")
+ .agg(count($"int"))
+ checkAnswer(innerJoin, Row(1) :: Nil)
+ }
+
}
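For quick reference, the two equivalent ways of requesting a broadcast join that the suite exercises; `small` and `large` are illustrative DataFrames, not names from the patch:

```scala
import org.apache.spark.sql.functions.broadcast

val small = spark.range(100).toDF("id")
val large = spark.range(10000000L).toDF("id")

large.join(broadcast(small), "id")         // broadcast() function hint
large.join(small.hint("broadcast"), "id")  // Dataset.hint string form
// Either form should yield a BroadcastHashJoinExec even when the size estimate exceeds
// spark.sql.autoBroadcastJoinThreshold.
```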
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
index 5e323c02b253d..45afbd29d1907 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
@@ -185,6 +185,23 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall
}
}
}
+
+ test("SPARK-20430 Initialize Range parameters in a driver side") {
+ withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
+ checkAnswer(sql("SELECT * FROM range(3)"), Row(0) :: Row(1) :: Row(2) :: Nil)
+ }
+ }
+
+ test("SPARK-21041 SparkSession.range()'s behavior is inconsistent with SparkContext.range()") {
+ val start = java.lang.Long.MAX_VALUE - 3
+ val end = java.lang.Long.MIN_VALUE + 2
+ Seq("false", "true").foreach { value =>
+ withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> value) {
+ assert(spark.range(start, end, 1).collect.length == 0)
+ assert(spark.range(start, start, 1).collect.length == 0)
+ }
+ }
+ }
}
object DataFrameRangeSuite {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 97890a035a62f..dd118f88e3bb3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -68,25 +68,38 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
}
test("randomSplit on reordered partitions") {
- // This test ensures that randomSplit does not create overlapping splits even when the
- // underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of
- // rows in each partition.
- val data =
- sparkContext.parallelize(1 to 600, 2).mapPartitions(scala.util.Random.shuffle(_)).toDF("id")
- val splits = data.randomSplit(Array[Double](2, 3), seed = 1)
- assert(splits.length == 2, "wrong number of splits")
+ def testNonOverlappingSplits(data: DataFrame): Unit = {
+ val splits = data.randomSplit(Array[Double](2, 3), seed = 1)
+ assert(splits.length == 2, "wrong number of splits")
+
+ // Verify that the splits span the entire dataset
+ assert(splits.flatMap(_.collect()).toSet == data.collect().toSet)
- // Verify that the splits span the entire dataset
- assert(splits.flatMap(_.collect()).toSet == data.collect().toSet)
+ // Verify that the splits don't overlap
+ assert(splits(0).collect().toSeq.intersect(splits(1).collect().toSeq).isEmpty)
- // Verify that the splits don't overlap
- assert(splits(0).intersect(splits(1)).collect().isEmpty)
+ // Verify that the results are deterministic across multiple runs
+ val firstRun = splits.toSeq.map(_.collect().toSeq)
+ val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq)
+ assert(firstRun == secondRun)
+ }
- // Verify that the results are deterministic across multiple runs
- val firstRun = splits.toSeq.map(_.collect().toSeq)
- val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq)
- assert(firstRun == secondRun)
+ // This test ensures that randomSplit does not create overlapping splits even when the
+ // underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of
+ // rows in each partition.
+ val dataWithInts = sparkContext.parallelize(1 to 600, 2)
+ .mapPartitions(scala.util.Random.shuffle(_)).toDF("int")
+ val dataWithMaps = sparkContext.parallelize(1 to 600, 2)
+ .map(i => (i, Map(i -> i.toString)))
+ .mapPartitions(scala.util.Random.shuffle(_)).toDF("int", "map")
+ val dataWithArrayOfMaps = sparkContext.parallelize(1 to 600, 2)
+ .map(i => (i, Array(Map(i -> i.toString))))
+ .mapPartitions(scala.util.Random.shuffle(_)).toDF("int", "arrayOfMaps")
+
+ testNonOverlappingSplits(dataWithInts)
+ testNonOverlappingSplits(dataWithMaps)
+ testNonOverlappingSplits(dataWithArrayOfMaps)
}
test("pearson correlation") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 52bd4e19f8952..7450a1a35b8f6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1722,4 +1722,71 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
"Cannot have map type columns in DataFrame which calls set operations"))
}
}
+
+ test("SPARK-20359: catalyst outer join optimization should not throw npe") {
+ val df1 = Seq("a", "b", "c").toDF("x")
+ .withColumn("y", udf{ (x: String) => x.substring(0, 1) + "!" }.apply($"x"))
+ val df2 = Seq("a", "b").toDF("x1")
+ df1
+ .join(df2, df1("x") === df2("x1"), "left_outer")
+ .filter($"x1".isNotNull || !$"y".isin("a!"))
+ .count
+ }
+
+  // The fix of SPARK-21720 avoids an exception regarding the JVM code size limit
+  // TODO: When we make the threshold for splitting statements (1024) configurable,
+  // we will re-enable this with the max threshold to cause an exception
+ // See https://github.com/apache/spark/pull/18972/files#r150223463
+ ignore("SPARK-19372: Filter can be executed w/o generated code due to JVM code size limit") {
+ val N = 400
+ val rows = Seq(Row.fromSeq(Seq.fill(N)("string")))
+ val schema = StructType(Seq.tabulate(N)(i => StructField(s"_c$i", StringType)))
+ val df = spark.createDataFrame(spark.sparkContext.makeRDD(rows), schema)
+
+ val filter = (0 until N)
+ .foldLeft(lit(false))((e, index) => e.or(df.col(df.columns(index)) =!= "string"))
+ df.filter(filter).count
+ }
+
+ test("SPARK-20897: cached self-join should not fail") {
+ // force to plan sort merge join
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
+ val df = Seq(1 -> "a").toDF("i", "j")
+ val df1 = df.as("t1")
+ val df2 = df.as("t2")
+ assert(df1.join(df2, $"t1.i" === $"t2.i").cache().count() == 1)
+ }
+ }
+
+ test("order-by ordinal.") {
+ checkAnswer(
+ testData2.select(lit(7), 'a, 'b).orderBy(lit(1), lit(2), lit(3)),
+ Seq(Row(7, 1, 1), Row(7, 1, 2), Row(7, 2, 1), Row(7, 2, 2), Row(7, 3, 1), Row(7, 3, 2)))
+ }
+
+ test("SPARK-22252: FileFormatWriter should respect the input query schema") {
+ withTable("t1", "t2", "t3", "t4") {
+ spark.range(1).select('id as 'col1, 'id as 'col2).write.saveAsTable("t1")
+ spark.sql("select COL1, COL2 from t1").write.saveAsTable("t2")
+ checkAnswer(spark.table("t2"), Row(0, 0))
+
+ // Test picking part of the columns when writing.
+ spark.range(1).select('id, 'id as 'col1, 'id as 'col2).write.saveAsTable("t3")
+ spark.sql("select COL1, COL2 from t3").write.saveAsTable("t4")
+ checkAnswer(spark.table("t4"), Row(0, 0))
+ }
+ }
+
+ test("SPARK-22271: mean overflows and returns null for some decimal variables") {
+ val d = 0.034567890
+ val df = Seq(d, d, d, d, d, d, d, d, d, d).toDF("DecimalCol")
+ val result = df.select('DecimalCol cast DecimalType(38, 33))
+ .select(col("DecimalCol")).describe()
+ val mean = result.select("DecimalCol").where($"summary" === "mean")
+ assert(mean.collect().toSet === Set(Row("0.0345678900000000000000000000000000000")))
+ }
+
+ test("SPARK-22469: compare string with decimal") {
+ checkAnswer(Seq("1.5").toDF("s").filter("s > 0.5"), Row("1.5"))
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
index 1255c49104718..204858fa29787 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
@@ -19,8 +19,9 @@ package org.apache.spark.sql
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window}
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{DataType, LongType, StructType}
+import org.apache.spark.sql.types._
/**
* Window function testing for DataFrame API.
@@ -423,4 +424,48 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
df.select(selectList: _*).where($"value" < 2),
Seq(Row(3, "1", null, 3.0, 4.0, 3.0), Row(5, "1", false, 4.0, 5.0, 5.0)))
}
+
+ test("SPARK-21258: complex object in combination with spilling") {
+ // Make sure we trigger the spilling path.
+ withSQLConf(SQLConf.WINDOW_EXEC_BUFFER_SPILL_THRESHOLD.key -> "17") {
+ val sampleSchema = new StructType().
+ add("f0", StringType).
+ add("f1", LongType).
+ add("f2", ArrayType(new StructType().
+ add("f20", StringType))).
+ add("f3", ArrayType(new StructType().
+ add("f30", StringType)))
+
+ val w0 = Window.partitionBy("f0").orderBy("f1")
+ val w1 = w0.rowsBetween(Long.MinValue, Long.MaxValue)
+
+ val c0 = first(struct($"f2", $"f3")).over(w0) as "c0"
+ val c1 = last(struct($"f2", $"f3")).over(w1) as "c1"
+
+ val input =
+ """{"f1":1497820153720,"f2":[{"f20":"x","f21":0}],"f3":[{"f30":"x","f31":0}]}
+ |{"f1":1497802179638}
+ |{"f1":1497802189347}
+ |{"f1":1497802189593}
+ |{"f1":1497802189597}
+ |{"f1":1497802189599}
+ |{"f1":1497802192103}
+ |{"f1":1497802193414}
+ |{"f1":1497802193577}
+ |{"f1":1497802193709}
+ |{"f1":1497802202883}
+ |{"f1":1497802203006}
+ |{"f1":1497802203743}
+ |{"f1":1497802203834}
+ |{"f1":1497802203887}
+ |{"f1":1497802203893}
+ |{"f1":1497802203976}
+ |{"f1":1497820168098}
+ |""".stripMargin.split("\n").toSeq
+
+ import testImplicits._
+
+ spark.read.schema(sampleSchema).json(input.toDS()).select(c0, c1).foreach { _ => () }
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index 541565344f758..212ee1b39adf1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -32,6 +32,9 @@ case class QueueClass(q: Queue[Int])
case class ComplexClass(seq: SeqClass, list: ListClass, queue: QueueClass)
+case class InnerData(name: String, value: Int)
+case class NestedData(id: Int, param: Map[String, InnerData])
+
package object packageobject {
case class PackageClass(value: Int)
}
@@ -258,9 +261,19 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
ListClass(List(1)) -> Queue("test" -> SeqClass(Seq(2))))
}
+ test("nested sequences") {
+ checkDataset(Seq(Seq(Seq(1))).toDS(), Seq(Seq(1)))
+ checkDataset(Seq(List(Queue(1))).toDS(), List(Queue(1)))
+ }
+
test("package objects") {
import packageobject._
checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1))
}
+ test("SPARK-19104: Lambda variables in ExternalMapToCatalyst should be global") {
+ val data = Seq.tabulate(10)(i => NestedData(1, Map("key" -> InnerData("name", i + 100))))
+ val ds = spark.createDataset(data)
+ checkDataset(ds, data: _*)
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala
index 92c5656f65bb4..68f7de047b392 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala
@@ -20,9 +20,9 @@ package org.apache.spark.sql
import com.esotericsoftware.kryo.{Kryo, Serializer}
import com.esotericsoftware.kryo.io.{Input, Output}
+import org.apache.spark.SparkConf
import org.apache.spark.serializer.KryoRegistrator
import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.test.TestSparkSession
/**
* Test suite to test Kryo custom registrators.
@@ -30,12 +30,10 @@ import org.apache.spark.sql.test.TestSparkSession
class DatasetSerializerRegistratorSuite extends QueryTest with SharedSQLContext {
import testImplicits._
- /**
- * Initialize the [[TestSparkSession]] with a [[KryoRegistrator]].
- */
- protected override def beforeAll(): Unit = {
- sparkConf.set("spark.kryo.registrator", TestRegistrator().getClass.getCanonicalName)
- super.beforeAll()
+
+ override protected def sparkConf: SparkConf = {
+ // Make sure we use the KryoRegistrator
+ super.sparkConf.set("spark.kryo.registrator", TestRegistrator().getClass.getCanonicalName)
}
test("Kryo registrator") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 5b5cd28ad0c99..683fe4a329365 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -20,12 +20,15 @@ package org.apache.spark.sql
import java.io.{Externalizable, ObjectInput, ObjectOutput}
import java.sql.{Date, Timestamp}
+import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
+import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
import org.apache.spark.sql.catalyst.util.sideBySide
import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SortExec}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchange}
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
@@ -320,6 +323,21 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
((("b", 2), ("b", 2)), ("b", 2)))
}
+ test("joinWith join types") {
+ val ds1 = Seq(1, 2, 3).toDS().as("a")
+ val ds2 = Seq(1, 2).toDS().as("b")
+
+ val e1 = intercept[AnalysisException] {
+ ds1.joinWith(ds2, $"a.value" === $"b.value", "left_semi")
+ }.getMessage
+ assert(e1.contains("Invalid join type in joinWith: " + LeftSemi.sql))
+
+ val e2 = intercept[AnalysisException] {
+ ds1.joinWith(ds2, $"a.value" === $"b.value", "left_anti")
+ }.getMessage
+ assert(e2.contains("Invalid join type in joinWith: " + LeftAnti.sql))
+ }
+
test("groupBy function, keys") {
val ds = Seq(("a", 1), ("b", 1)).toDS()
val grouped = ds.groupByKey(v => (1, v._2))
@@ -1168,6 +1186,58 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
val ds = Seq(WithMapInOption(Some(Map(1 -> 1)))).toDS()
checkDataset(ds, WithMapInOption(Some(Map(1 -> 1))))
}
+
+ test("SPARK-20399: do not unescaped regex pattern when ESCAPED_STRING_LITERALS is enabled") {
+ withSQLConf(SQLConf.ESCAPED_STRING_LITERALS.key -> "true") {
+ val data = Seq("\u0020\u0021\u0023", "abc")
+ val df = data.toDF()
+ val rlike1 = df.filter("value rlike '^\\x20[\\x20-\\x23]+$'")
+ val rlike2 = df.filter($"value".rlike("^\\x20[\\x20-\\x23]+$"))
+ val rlike3 = df.filter("value rlike '^\\\\x20[\\\\x20-\\\\x23]+$'")
+ checkAnswer(rlike1, rlike2)
+ assert(rlike3.count() == 0)
+ }
+ }
+
+ test("SPARK-21538: Attribute resolution inconsistency in Dataset API") {
+ val df = spark.range(3).withColumnRenamed("id", "x")
+ val expected = Row(0) :: Row(1) :: Row (2) :: Nil
+ checkAnswer(df.sort("id"), expected)
+ checkAnswer(df.sort(col("id")), expected)
+ checkAnswer(df.sort($"id"), expected)
+ checkAnswer(df.sort('id), expected)
+ checkAnswer(df.orderBy("id"), expected)
+ checkAnswer(df.orderBy(col("id")), expected)
+ checkAnswer(df.orderBy($"id"), expected)
+ checkAnswer(df.orderBy('id), expected)
+ }
+
+ test("SPARK-22472: add null check for top-level primitive values") {
+    // If the primitive values are from Option, we need to do a runtime null check.
+ val ds = Seq(Some(1), None).toDS().as[Int]
+ intercept[NullPointerException](ds.collect())
+ val e = intercept[SparkException](ds.map(_ * 2).collect())
+ assert(e.getCause.isInstanceOf[NullPointerException])
+
+ withTempPath { path =>
+ Seq(new Integer(1), null).toDF("i").write.parquet(path.getCanonicalPath)
+      // If the primitive values are from files, we need to do a runtime null check.
+ val ds = spark.read.parquet(path.getCanonicalPath).as[Int]
+ intercept[NullPointerException](ds.collect())
+ val e = intercept[SparkException](ds.map(_ * 2).collect())
+ assert(e.getCause.isInstanceOf[NullPointerException])
+ }
+ }
+
+ test("SPARK-22442: Generate correct field names for special characters") {
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+ val data = """{"field.1": 1, "field 2": 2}"""
+ Seq(data).toDF().repartition(1).write.text(path)
+ val ds = spark.read.json(path).as[SpecialCharClass]
+ checkDataset(ds, SpecialCharClass("1", "2"))
+ }
+ }
}
case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map[Long, String])
@@ -1253,3 +1323,5 @@ case class CircularReferenceClassB(cls: CircularReferenceClassA)
case class CircularReferenceClassC(ar: Array[CircularReferenceClassC])
case class CircularReferenceClassD(map: Map[String, CircularReferenceClassE])
case class CircularReferenceClassE(id: String, list: List[CircularReferenceClassD])
+
+case class SpecialCharClass(`field.1`: String, `field 2`: String)
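A short note on the backtick trick behind `SpecialCharClass`: Scala back-quoted identifiers allow case-class fields whose names contain dots or spaces, so the derived encoder can bind them by name to the corresponding JSON columns. An illustrative, hypothetical variant:

```scala
// Field names here are made up; the by-name binding mirrors the SPARK-22442 test above.
case class Metric(`p50.ms`: Long, `error rate`: Double)
// spark.read.json(path).as[Metric] would bind the columns named "p50.ms" and "error rate".
```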
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
index cef5bbf0e85a7..b9871afd59e4f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
@@ -91,7 +91,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
val df = Seq((1, Seq(1, 2, 3)), (2, Seq())).toDF("a", "intList")
checkAnswer(
df.select(explode_outer('intList)),
- Row(1) :: Row(2) :: Row(3) :: Nil)
+ Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil)
}
test("single posexplode") {
@@ -105,7 +105,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
val df = Seq((1, Seq(1, 2, 3)), (2, Seq())).toDF("a", "intList")
checkAnswer(
df.select(posexplode_outer('intList)),
- Row(0, 1) :: Row(1, 2) :: Row(2, 3) :: Nil)
+ Row(0, 1) :: Row(1, 2) :: Row(2, 3) :: Row(null, null) :: Nil)
}
test("explode and other columns") {
@@ -161,7 +161,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
checkAnswer(
df.select(explode_outer('intList).as('int)).select('int),
- Row(1) :: Row(2) :: Row(3) :: Nil)
+ Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil)
checkAnswer(
df.select(explode('intList).as('int)).select(sum('int)),
@@ -182,7 +182,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
checkAnswer(
df.select(explode_outer('map)),
- Row("a", "b") :: Row("c", "d") :: Nil)
+ Row("a", "b") :: Row(null, null) :: Row("c", "d") :: Nil)
}
test("explode on map with aliases") {
@@ -198,7 +198,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
checkAnswer(
df.select(explode_outer('map).as("key1" :: "value1" :: Nil)).select("key1", "value1"),
- Row("a", "b") :: Nil)
+ Row("a", "b") :: Row(null, null) :: Nil)
}
test("self join explode") {
@@ -279,7 +279,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
)
checkAnswer(
df2.selectExpr("inline_outer(col1)"),
- Row(3, "4") :: Row(5, "6") :: Nil
+ Row(null, null) :: Row(3, "4") :: Row(5, "6") :: Nil
)
}
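The updated expectations hinge on the `_outer` generator variants: `explode` drops rows whose array or map is empty or null, while `explode_outer` keeps them and emits nulls. A minimal sketch, assuming `spark.implicits._` (or a suite's `testImplicits._`) is in scope:

```scala
import org.apache.spark.sql.functions.{explode, explode_outer}

val df = Seq((1, Seq(1, 2)), (2, Seq.empty[Int])).toDF("a", "intList")
df.select($"a", explode($"intList")).show()        // only rows for a = 1; the empty list is dropped
df.select($"a", explode_outer($"intList")).show()  // additionally a (2, null) row for the empty list
```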
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 1a66aa85f5a02..cdfd33dfb91a3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql
+import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer
import scala.language.existentials
@@ -25,6 +26,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.StructType
import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled}
class JoinSuite extends QueryTest with SharedSQLContext {
@@ -198,6 +200,14 @@ class JoinSuite extends QueryTest with SharedSQLContext {
Nil)
}
+ test("SPARK-22141: Propagate empty relation before checking Cartesian products") {
+ Seq("inner", "left", "right", "left_outer", "right_outer", "full_outer").foreach { joinType =>
+ val x = testData2.where($"a" === 2 && !($"a" === 2)).as("x")
+ val y = testData2.where($"a" === 1 && !($"a" === 1)).as("y")
+ checkAnswer(x.join(y, Seq.empty, joinType), Nil)
+ }
+ }
+
test("big inner join, 4 matches per row") {
val bigData = testData.union(testData).union(testData).union(testData)
val bigDataX = bigData.as("x")
@@ -665,7 +675,8 @@ class JoinSuite extends QueryTest with SharedSQLContext {
test("test SortMergeJoin (with spill)") {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1",
- "spark.sql.sortMergeJoinExec.buffer.spill.threshold" -> "0") {
+ "spark.sql.sortMergeJoinExec.buffer.in.memory.threshold" -> "0",
+ "spark.sql.sortMergeJoinExec.buffer.spill.threshold" -> "1") {
assertSpilled(sparkContext, "inner join") {
checkAnswer(
@@ -738,4 +749,22 @@ class JoinSuite extends QueryTest with SharedSQLContext {
}
}
}
+
+ test("outer broadcast hash join should not throw NPE") {
+ withTempView("v1", "v2") {
+ withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true") {
+ Seq(2 -> 2).toDF("x", "y").createTempView("v1")
+
+ spark.createDataFrame(
+ Seq(Row(1, "a")).asJava,
+ new StructType().add("i", "int", nullable = false).add("j", "string", nullable = false)
+ ).createTempView("v2")
+
+ checkAnswer(
+ sql("select x, y, i, j from v1 left join v2 on x = i and y < length(j)"),
+ Row(2, 2, null, null)
+ )
+ }
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
index 8465e8d036a6d..989f8c23a4069 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql
-import org.apache.spark.sql.functions.{from_json, struct, to_json}
+import org.apache.spark.sql.functions.{from_json, lit, map, struct, to_json}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
@@ -188,15 +188,33 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
Row("""{"_1":"26/08/2015 18:00"}""") :: Nil)
}
- test("to_json unsupported type") {
+ test("to_json - key types of map don't matter") {
+ // interval type is invalid for converting to JSON. However, the keys of a map are treated
+    // as strings, so their type doesn't matter.
val df = Seq(Tuple1(Tuple1("interval -3 month 7 hours"))).toDF("a")
- .select(struct($"a._1".cast(CalendarIntervalType).as("a")).as("c"))
+ .select(struct(map($"a._1".cast(CalendarIntervalType), lit("a")).as("col1")).as("c"))
+ checkAnswer(
+ df.select(to_json($"c")),
+ Row("""{"col1":{"interval -3 months 7 hours":"a"}}""") :: Nil)
+ }
+
+ test("to_json unsupported type") {
+ val baseDf = Seq(Tuple1(Tuple1("interval -3 month 7 hours"))).toDF("a")
+ val df = baseDf.select(struct($"a._1".cast(CalendarIntervalType).as("a")).as("c"))
val e = intercept[AnalysisException]{
// Unsupported type throws an exception
df.select(to_json($"c")).collect()
}
assert(e.getMessage.contains(
"Unable to convert column a of type calendarinterval to JSON."))
+
+    // interval type is invalid for converting to JSON. We can't use it as the value type of a map.
+ val df2 = baseDf
+ .select(struct(map(lit("a"), $"a._1".cast(CalendarIntervalType)).as("col1")).as("c"))
+ val e2 = intercept[AnalysisException] {
+ df2.select(to_json($"c")).collect()
+ }
+ assert(e2.getMessage.contains("Unable to convert column col1 of type calendarinterval to JSON"))
}
test("roundtrip in to_json and from_json - struct") {
@@ -274,7 +292,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
val errMsg2 = intercept[AnalysisException] {
df3.selectExpr("""from_json(value, 'time InvalidType')""")
}
- assert(errMsg2.getMessage.contains("DataType invalidtype() is not supported"))
+ assert(errMsg2.getMessage.contains("DataType invalidtype is not supported"))
val errMsg3 = intercept[AnalysisException] {
df3.selectExpr("from_json(value, 'time Timestamp', named_struct('a', 1))")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala
index 328c5395ec91e..5be8c581e9ddb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala
@@ -231,6 +231,19 @@ class MathFunctionsSuite extends QueryTest with SharedSQLContext {
Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3),
BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142")))
)
+
+ val bdPi: BigDecimal = BigDecimal(31415925L, 7)
+ checkAnswer(
+ sql(s"SELECT round($bdPi, 7), round($bdPi, 8), round($bdPi, 9), round($bdPi, 10), " +
+ s"round($bdPi, 100), round($bdPi, 6), round(null, 8)"),
+ Seq(Row(bdPi, bdPi, bdPi, bdPi, bdPi, BigDecimal("3.141593"), null))
+ )
+
+ checkAnswer(
+ sql(s"SELECT bround($bdPi, 7), bround($bdPi, 8), bround($bdPi, 9), bround($bdPi, 10), " +
+ s"bround($bdPi, 100), bround($bdPi, 6), bround(null, 8)"),
+ Seq(Row(bdPi, bdPi, bdPi, bdPi, bdPi, BigDecimal("3.141592"), null))
+ )
}
test("round/bround with data frame from a local Seq of Product") {
@@ -245,6 +258,18 @@ class MathFunctionsSuite extends QueryTest with SharedSQLContext {
)
}
+ test("round/bround with table columns") {
+ withTable("t") {
+ Seq(BigDecimal("5.9")).toDF("i").write.saveAsTable("t")
+ checkAnswer(
+ sql("select i, round(i) from t"),
+ Seq(Row(BigDecimal("5.9"), BigDecimal("6"))))
+ checkAnswer(
+ sql("select i, bround(i) from t"),
+ Seq(Row(BigDecimal("5.9"), BigDecimal("6"))))
+ }
+ }
+
test("exp") {
testOneToOneMathFunction(exp, math.exp)
}
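The round/bround expectations follow from the rounding modes the two functions use: `round` is HALF_UP while `bround` is HALF_EVEN (banker's rounding), which is why rounding the truncated pi value to 6 places yields 3.141593 for round but 3.141592 for bround. A quick hedged check:

```scala
// Sketch; relies only on the rounding-mode difference between the two built-ins.
spark.sql("SELECT round(2.5), bround(2.5)").show()  // 3 and 2 (HALF_UP vs HALF_EVEN)
spark.sql("SELECT round(3.5), bround(3.5)").show()  // 4 and 4 (both land on the even 4)
```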
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala
index 52c200796ce41..623a1b6f854cf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala
@@ -22,20 +22,22 @@ import java.util.concurrent.TimeUnit
import scala.concurrent.duration._
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.streaming.ProcessingTime
+import org.apache.spark.sql.streaming.{ProcessingTime, Trigger}
class ProcessingTimeSuite extends SparkFunSuite {
test("create") {
- assert(ProcessingTime(10.seconds).intervalMs === 10 * 1000)
- assert(ProcessingTime.create(10, TimeUnit.SECONDS).intervalMs === 10 * 1000)
- assert(ProcessingTime("1 minute").intervalMs === 60 * 1000)
- assert(ProcessingTime("interval 1 minute").intervalMs === 60 * 1000)
-
- intercept[IllegalArgumentException] { ProcessingTime(null: String) }
- intercept[IllegalArgumentException] { ProcessingTime("") }
- intercept[IllegalArgumentException] { ProcessingTime("invalid") }
- intercept[IllegalArgumentException] { ProcessingTime("1 month") }
- intercept[IllegalArgumentException] { ProcessingTime("1 year") }
+ def getIntervalMs(trigger: Trigger): Long = trigger.asInstanceOf[ProcessingTime].intervalMs
+
+ assert(getIntervalMs(Trigger.ProcessingTime(10.seconds)) === 10 * 1000)
+ assert(getIntervalMs(Trigger.ProcessingTime(10, TimeUnit.SECONDS)) === 10 * 1000)
+ assert(getIntervalMs(Trigger.ProcessingTime("1 minute")) === 60 * 1000)
+ assert(getIntervalMs(Trigger.ProcessingTime("interval 1 minute")) === 60 * 1000)
+
+ intercept[IllegalArgumentException] { Trigger.ProcessingTime(null: String) }
+ intercept[IllegalArgumentException] { Trigger.ProcessingTime("") }
+ intercept[IllegalArgumentException] { Trigger.ProcessingTime("invalid") }
+ intercept[IllegalArgumentException] { Trigger.ProcessingTime("1 month") }
+ intercept[IllegalArgumentException] { Trigger.ProcessingTime("1 year") }
}
}
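In normal use, `Trigger.ProcessingTime` is handed to a streaming writer rather than inspected directly; a hedged usage sketch in which the rate source, console sink, and interval are illustrative:

```scala
import scala.concurrent.duration._
import org.apache.spark.sql.streaming.Trigger

val query = spark.readStream.format("rate").load()
  .writeStream
  .format("console")
  .trigger(Trigger.ProcessingTime(10.seconds))  // equivalent to Trigger.ProcessingTime("10 seconds")
  .start()
```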
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 0dd9296a3f0ff..d2b17a3b7b994 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
import java.io.File
import java.math.MathContext
+import java.net.{MalformedURLException, URL}
import java.sql.Timestamp
import java.util.concurrent.atomic.AtomicBoolean
@@ -1636,6 +1637,46 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
}
+ test("SPARK-23281: verify the correctness of sort direction on composite order by clause") {
+ withTempView("src") {
+ Seq[(Integer, Integer)](
+ (1, 1),
+ (1, 3),
+ (2, 3),
+ (3, 3),
+ (4, null),
+ (5, null)
+ ).toDF("key", "value").createOrReplaceTempView("src")
+
+ checkAnswer(sql(
+ """
+ |SELECT MAX(value) as value, key as col2
+ |FROM src
+ |GROUP BY key
+ |ORDER BY value desc, key
+ """.stripMargin),
+ Seq(Row(3, 1), Row(3, 2), Row(3, 3), Row(null, 4), Row(null, 5)))
+
+ checkAnswer(sql(
+ """
+ |SELECT MAX(value) as value, key as col2
+ |FROM src
+ |GROUP BY key
+ |ORDER BY value desc, key desc
+ """.stripMargin),
+ Seq(Row(3, 3), Row(3, 2), Row(3, 1), Row(null, 5), Row(null, 4)))
+
+ checkAnswer(sql(
+ """
+ |SELECT MAX(value) as value, key as col2
+ |FROM src
+ |GROUP BY key
+ |ORDER BY value asc, key desc
+ """.stripMargin),
+ Seq(Row(null, 5), Row(null, 4), Row(3, 3), Row(3, 2), Row(3, 1)))
+ }
+ }
+
test("run sql directly on files") {
val df = spark.range(100).toDF()
withTempPath(f => {
@@ -2606,4 +2647,67 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
case ae: AnalysisException => assert(ae.plan == null && ae.getMessage == ae.getSimpleMessage)
}
}
+
+ test("SPARK-12868: Allow adding jars from hdfs ") {
+ val jarFromHdfs = "hdfs://doesnotmatter/test.jar"
+ val jarFromInvalidFs = "fffs://doesnotmatter/test.jar"
+
+ // if 'hdfs' is not supported, MalformedURLException will be thrown
+ new URL(jarFromHdfs)
+
+ intercept[MalformedURLException] {
+ new URL(jarFromInvalidFs)
+ }
+ }
+
+ test("RuntimeReplaceable functions should not take extra parameters") {
+ val e = intercept[AnalysisException](sql("SELECT nvl(1, 2, 3)"))
+ assert(e.message.contains("Invalid number of arguments"))
+ }
+
+ test("SPARK-21228: InSet incorrect handling of structs") {
+ withTempView("A") {
+ // reduce this from the default of 10 so the repro query text is not too long
+ withSQLConf((SQLConf.OPTIMIZER_INSET_CONVERSION_THRESHOLD.key -> "3")) {
+ // a relation that has 1 column of struct type with values (1,1), ..., (9, 9)
+ spark.range(1, 10).selectExpr("named_struct('a', id, 'b', id) as a")
+ .createOrReplaceTempView("A")
+ val df = sql(
+ """
+ |SELECT * from
+ | (SELECT MIN(a) as minA FROM A) AA -- this Aggregate will return UnsafeRows
+ | -- the IN will become InSet with a Set of GenericInternalRows
+ | -- a GenericInternalRow is never equal to an UnsafeRow so the query would
+          |  -- return 0 results, which is incorrect
+ | WHERE minA IN (NAMED_STRUCT('a', 1L, 'b', 1L), NAMED_STRUCT('a', 2L, 'b', 2L),
+ | NAMED_STRUCT('a', 3L, 'b', 3L))
+ """.stripMargin)
+ checkAnswer(df, Row(Row(1, 1)))
+ }
+ }
+ }
+
+ test("SPARK-22356: overlapped columns between data and partition schema in data source tables") {
+ withTempPath { path =>
+ Seq((1, 1, 1), (1, 2, 1)).toDF("i", "p", "j")
+ .write.mode("overwrite").parquet(new File(path, "p=1").getCanonicalPath)
+ withTable("t") {
+ sql(s"create table t using parquet options(path='${path.getCanonicalPath}')")
+ // We should respect the column order in data schema.
+ assert(spark.table("t").columns === Array("i", "p", "j"))
+ checkAnswer(spark.table("t"), Row(1, 1, 1) :: Row(1, 1, 1) :: Nil)
+ // The DESC TABLE should report same schema as table scan.
+ assert(sql("desc t").select("col_name")
+ .as[String].collect().mkString(",").contains("i,p,j"))
+ }
+ }
+ }
+
+ test("SPARK-25144 'distinct' causes memory leak") {
+ val ds = List(Foo(Some("bar"))).toDS
+ val result = ds.flatMap(_.bar).distinct
+ result.rdd.isEmpty
+ }
}
+
+case class Foo(bar: Option[String])
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala
index 386d13d07a95f..1c6afa5e26e14 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala
@@ -17,49 +17,48 @@
package org.apache.spark.sql
+import org.scalatest.BeforeAndAfterEach
+
import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
/**
* Test cases for the builder pattern of [[SparkSession]].
*/
-class SparkSessionBuilderSuite extends SparkFunSuite {
+class SparkSessionBuilderSuite extends SparkFunSuite with BeforeAndAfterEach {
- private var initialSession: SparkSession = _
+ override def afterEach(): Unit = {
+ // This suite should not interfere with the other test suites.
+ SparkSession.getActiveSession.foreach(_.stop())
+ SparkSession.clearActiveSession()
+ SparkSession.getDefaultSession.foreach(_.stop())
+ SparkSession.clearDefaultSession()
+ }
- private lazy val sparkContext: SparkContext = {
- initialSession = SparkSession.builder()
+ test("create with config options and propagate them to SparkContext and SparkSession") {
+ val session = SparkSession.builder()
.master("local")
.config("spark.ui.enabled", value = false)
.config("some-config", "v2")
.getOrCreate()
- initialSession.sparkContext
- }
-
- test("create with config options and propagate them to SparkContext and SparkSession") {
- // Creating a new session with config - this works by just calling the lazy val
- sparkContext
- assert(initialSession.sparkContext.conf.get("some-config") == "v2")
- assert(initialSession.conf.get("some-config") == "v2")
- SparkSession.clearDefaultSession()
+ assert(session.sparkContext.conf.get("some-config") == "v2")
+ assert(session.conf.get("some-config") == "v2")
}
test("use global default session") {
- val session = SparkSession.builder().getOrCreate()
+ val session = SparkSession.builder().master("local").getOrCreate()
assert(SparkSession.builder().getOrCreate() == session)
- SparkSession.clearDefaultSession()
}
test("config options are propagated to existing SparkSession") {
- val session1 = SparkSession.builder().config("spark-config1", "a").getOrCreate()
+ val session1 = SparkSession.builder().master("local").config("spark-config1", "a").getOrCreate()
assert(session1.conf.get("spark-config1") == "a")
val session2 = SparkSession.builder().config("spark-config1", "b").getOrCreate()
assert(session1 == session2)
assert(session1.conf.get("spark-config1") == "b")
- SparkSession.clearDefaultSession()
}
test("use session from active thread session and propagate config options") {
- val defaultSession = SparkSession.builder().getOrCreate()
+ val defaultSession = SparkSession.builder().master("local").getOrCreate()
val activeSession = defaultSession.newSession()
SparkSession.setActiveSession(activeSession)
val session = SparkSession.builder().config("spark-config2", "a").getOrCreate()
@@ -70,16 +69,14 @@ class SparkSessionBuilderSuite extends SparkFunSuite {
SparkSession.clearActiveSession()
assert(SparkSession.builder().getOrCreate() == defaultSession)
- SparkSession.clearDefaultSession()
}
test("create a new session if the default session has been stopped") {
- val defaultSession = SparkSession.builder().getOrCreate()
+ val defaultSession = SparkSession.builder().master("local").getOrCreate()
SparkSession.setDefaultSession(defaultSession)
defaultSession.stop()
val newSession = SparkSession.builder().master("local").getOrCreate()
assert(newSession != defaultSession)
- newSession.stop()
}
test("create a new session if the active thread session has been stopped") {
@@ -88,11 +85,9 @@ class SparkSessionBuilderSuite extends SparkFunSuite {
activeSession.stop()
val newSession = SparkSession.builder().master("local").getOrCreate()
assert(newSession != activeSession)
- newSession.stop()
}
test("create SparkContext first then SparkSession") {
- sparkContext.stop()
val conf = new SparkConf().setAppName("test").setMaster("local").set("key1", "value1")
val sparkContext2 = new SparkContext(conf)
val session = SparkSession.builder().config("key2", "value2").getOrCreate()
@@ -101,14 +96,12 @@ class SparkSessionBuilderSuite extends SparkFunSuite {
assert(session.sparkContext.conf.get("key1") == "value1")
assert(session.sparkContext.conf.get("key2") == "value2")
assert(session.sparkContext.conf.get("spark.app.name") == "test")
- session.stop()
}
test("SPARK-15887: hive-site.xml should be loaded") {
val session = SparkSession.builder().master("local").getOrCreate()
assert(session.sessionState.newHadoopConf().get("hive.in.test") == "true")
assert(session.sparkContext.hadoopConfiguration.get("hive.in.test") == "true")
- session.stop()
}
test("SPARK-15991: Set global Hadoop conf") {
@@ -120,7 +113,6 @@ class SparkSessionBuilderSuite extends SparkFunSuite {
assert(session.sessionState.newHadoopConf().get(mySpecialKey) == mySpecialValue)
} finally {
session.sparkContext.hadoopConfiguration.unset(mySpecialKey)
- session.stop()
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
new file mode 100644
index 0000000000000..43db79663322a
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.{SparkPlan, SparkStrategy}
+import org.apache.spark.sql.types.{DataType, StructType}
+
+/**
+ * Test cases for the [[SparkSessionExtensions]].
+ */
+class SparkSessionExtensionSuite extends SparkFunSuite {
+ type ExtensionsBuilder = SparkSessionExtensions => Unit
+ private def create(builder: ExtensionsBuilder): ExtensionsBuilder = builder
+
+ private def stop(spark: SparkSession): Unit = {
+ spark.stop()
+ SparkSession.clearActiveSession()
+ SparkSession.clearDefaultSession()
+ }
+
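+  // Builds a session with the given extensions applied and always stops it afterwards.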
+ private def withSession(builder: ExtensionsBuilder)(f: SparkSession => Unit): Unit = {
+ val spark = SparkSession.builder().master("local[1]").withExtensions(builder).getOrCreate()
+ try f(spark) finally {
+ stop(spark)
+ }
+ }
+
+ test("inject analyzer rule") {
+ withSession(_.injectResolutionRule(MyRule)) { session =>
+ assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule(session)))
+ }
+ }
+
+ test("inject check analysis rule") {
+ withSession(_.injectCheckRule(MyCheckRule)) { session =>
+ assert(session.sessionState.analyzer.extendedCheckRules.contains(MyCheckRule(session)))
+ }
+ }
+
+ test("inject optimizer rule") {
+ withSession(_.injectOptimizerRule(MyRule)) { session =>
+ assert(session.sessionState.optimizer.batches.flatMap(_.rules).contains(MyRule(session)))
+ }
+ }
+
+ test("inject spark planner strategy") {
+ withSession(_.injectPlannerStrategy(MySparkStrategy)) { session =>
+ assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session)))
+ }
+ }
+
+ test("inject parser") {
+ val extension = create { extensions =>
+ extensions.injectParser((_, _) => CatalystSqlParser)
+ }
+ withSession(extension) { session =>
+ assert(session.sessionState.sqlParser == CatalystSqlParser)
+ }
+ }
+
+ test("inject stacked parsers") {
+ val extension = create { extensions =>
+ extensions.injectParser((_, _) => CatalystSqlParser)
+ extensions.injectParser(MyParser)
+ extensions.injectParser(MyParser)
+ }
+ withSession(extension) { session =>
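+      // Each injected parser wraps the previously built one, so the last injection becomes the outermost parser.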
+ val parser = MyParser(session, MyParser(session, CatalystSqlParser))
+ assert(session.sessionState.sqlParser == parser)
+ }
+ }
+
+ test("use custom class for extensions") {
+ val session = SparkSession.builder()
+ .master("local[1]")
+ .config("spark.sql.extensions", classOf[MyExtensions].getCanonicalName)
+ .getOrCreate()
+ try {
+ assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session)))
+ assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule(session)))
+ } finally {
+ stop(session)
+ }
+ }
+}
+
+case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] {
+ override def apply(plan: LogicalPlan): LogicalPlan = plan
+}
+
+case class MyCheckRule(spark: SparkSession) extends (LogicalPlan => Unit) {
+ override def apply(plan: LogicalPlan): Unit = { }
+}
+
+case class MySparkStrategy(spark: SparkSession) extends SparkStrategy {
+ override def apply(plan: LogicalPlan): Seq[SparkPlan] = Seq.empty
+}
+
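+// A pass-through parser that delegates every call to the wrapped parser; used to verify parser stacking.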
+case class MyParser(spark: SparkSession, delegate: ParserInterface) extends ParserInterface {
+ override def parsePlan(sqlText: String): LogicalPlan =
+ delegate.parsePlan(sqlText)
+
+ override def parseExpression(sqlText: String): Expression =
+ delegate.parseExpression(sqlText)
+
+ override def parseTableIdentifier(sqlText: String): TableIdentifier =
+ delegate.parseTableIdentifier(sqlText)
+
+ override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier =
+ delegate.parseFunctionIdentifier(sqlText)
+
+ override def parseTableSchema(sqlText: String): StructType =
+ delegate.parseTableSchema(sqlText)
+
+ override def parseDataType(sqlText: String): DataType =
+ delegate.parseDataType(sqlText)
+}
+
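+// Extensions loaded via the spark.sql.extensions config in the "use custom class for extensions" test.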
+class MyExtensions extends (SparkSessionExtensions => Unit) {
+ def apply(e: SparkSessionExtensions): Unit = {
+ e.injectPlannerStrategy(MySparkStrategy)
+ e.injectResolutionRule(MyRule)
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
index ddc393c8da053..86d19af9dd548 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
@@ -24,7 +24,7 @@ import scala.collection.mutable
import scala.util.Random
import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics}
+import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, HiveTableRelation}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.datasources.LogicalRelation
@@ -40,17 +40,6 @@ import org.apache.spark.sql.types._
class StatisticsCollectionSuite extends StatisticsCollectionTestBase with SharedSQLContext {
import testImplicits._
- private def checkTableStats(tableName: String, expectedRowCount: Option[Int])
- : Option[CatalogStatistics] = {
- val df = spark.table(tableName)
- val stats = df.queryExecution.analyzed.collect { case rel: LogicalRelation =>
- assert(rel.catalogTable.get.stats.flatMap(_.rowCount) === expectedRowCount)
- rel.catalogTable.get.stats
- }
- assert(stats.size == 1)
- stats.head
- }
-
test("estimates the size of a limit 0 on outer join") {
withTempView("test") {
Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v")
@@ -88,6 +77,19 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared
}
}
+ test("analyze empty table") {
+ val table = "emptyTable"
+ withTable(table) {
+ sql(s"CREATE TABLE $table (key STRING, value STRING) USING PARQUET")
+ sql(s"ANALYZE TABLE $table COMPUTE STATISTICS noscan")
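+      // NOSCAN only computes the size in bytes, so no row count is expected yet.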
+ val fetchedStats1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = None)
+ assert(fetchedStats1.get.sizeInBytes == 0)
+ sql(s"ANALYZE TABLE $table COMPUTE STATISTICS")
+ val fetchedStats2 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0))
+ assert(fetchedStats2.get.sizeInBytes == 0)
+ }
+ }
+
test("test table-level statistics for data source table") {
val tableName = "tbl"
withTable(tableName) {
@@ -96,11 +98,11 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared
// noscan won't count the number of rows
sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS noscan")
- checkTableStats(tableName, expectedRowCount = None)
+ checkTableStats(tableName, hasSizeInBytes = true, expectedRowCounts = None)
// without noscan, we count the number of rows
sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS")
- checkTableStats(tableName, expectedRowCount = Some(2))
+ checkTableStats(tableName, hasSizeInBytes = true, expectedRowCounts = Some(2))
}
}
@@ -164,7 +166,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared
numbers.foreach { case (input, (expectedSize, expectedRows)) =>
val stats = Statistics(sizeInBytes = input, rowCount = Some(input))
val expectedString = s"sizeInBytes=$expectedSize, rowCount=$expectedRows," +
- s" isBroadcastable=${stats.isBroadcastable}"
+ s" hints=none"
assert(stats.simpleString == expectedString)
}
}
@@ -219,6 +221,23 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils
private val randomName = new Random(31)
+ def checkTableStats(
+ tableName: String,
+ hasSizeInBytes: Boolean,
+ expectedRowCounts: Option[Int]): Option[CatalogStatistics] = {
+ val stats = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)).stats
+
+ if (hasSizeInBytes || expectedRowCounts.nonEmpty) {
+ assert(stats.isDefined)
+ assert(stats.get.sizeInBytes >= 0)
+ assert(stats.get.rowCount === expectedRowCounts)
+ } else {
+ assert(stats.isEmpty)
+ }
+
+ stats
+ }
+
/**
* Compute column stats for the given DataFrame and compare it with colStats.
*/
@@ -285,7 +304,7 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils
// Analyze only one column.
sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS c1")
val (relation, catalogTable) = spark.table(tableName).queryExecution.analyzed.collect {
- case catalogRel: CatalogRelation => (catalogRel, catalogRel.tableMeta)
+ case catalogRel: HiveTableRelation => (catalogRel, catalogRel.tableMeta)
case logicalRel: LogicalRelation => (logicalRel, logicalRel.catalogTable.get)
}.head
val emptyColStat = ColumnStat(0, None, None, 0, 4, 4)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 0f0199cbe2777..2a3bdfbfa0108 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -72,7 +72,7 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
}
}
- test("rdd deserialization does not crash [SPARK-15791]") {
+ test("SPARK-15791: rdd deserialization does not crash") {
sql("select (select 1 as b) as b").rdd.count()
}
@@ -854,4 +854,12 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
sql("select * from l, r where l.a = r.c + 1 AND (exists (select * from r) OR l.a = r.c)"),
Row(3, 3.0, 2, 3.0) :: Row(3, 3.0, 2, 3.0) :: Nil)
}
+
+ test("SPARK-20688: correctly check analysis for scalar sub-queries") {
+ withTempView("t") {
+ Seq(1 -> "a").toDF("i", "j").createTempView("t")
+ val e = intercept[AnalysisException](sql("SELECT (SELECT count(*) FROM t WHERE a = 1)"))
+ assert(e.message.contains("cannot resolve '`a`' given input columns: [i, j]"))
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala
new file mode 100644
index 0000000000000..e47d4b0ee25d4
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala
@@ -0,0 +1,372 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.sql.catalyst.util.resourceToString
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.util.Utils
+
+/**
+ * This test suite ensures all the TPC-DS queries can be successfully analyzed and optimized
+ * without hitting the max iteration threshold.
+ */
+class TPCDSQuerySuite extends QueryTest with SharedSQLContext with BeforeAndAfterAll {
+
+  // When Utils.isTesting is true, the RuleExecutor will throw an exception if it hits
+  // the max iteration limit of the analyzer/optimizer batches.
+ assert(Utils.isTesting, "spark.testing is not set to true")
+
+ /**
+ * Drop all the tables
+ */
+ protected override def afterAll(): Unit = {
+ try {
+ spark.sessionState.catalog.reset()
+ } finally {
+ super.afterAll()
+ }
+ }
+
+  override def beforeAll(): Unit = {
+ super.beforeAll()
+ sql(
+ """
+ |CREATE TABLE `catalog_page` (
+ |`cp_catalog_page_sk` INT, `cp_catalog_page_id` STRING, `cp_start_date_sk` INT,
+ |`cp_end_date_sk` INT, `cp_department` STRING, `cp_catalog_number` INT,
+ |`cp_catalog_page_number` INT, `cp_description` STRING, `cp_type` STRING)
+ |USING parquet
+ """.stripMargin)
+
+ sql(
+ """
+ |CREATE TABLE `catalog_returns` (
+ |`cr_returned_date_sk` INT, `cr_returned_time_sk` INT, `cr_item_sk` INT,
+ |`cr_refunded_customer_sk` INT, `cr_refunded_cdemo_sk` INT, `cr_refunded_hdemo_sk` INT,
+ |`cr_refunded_addr_sk` INT, `cr_returning_customer_sk` INT, `cr_returning_cdemo_sk` INT,
+ |`cr_returning_hdemo_sk` INT, `cr_returning_addr_sk` INT, `cr_call_center_sk` INT,
+ |`cr_catalog_page_sk` INT, `cr_ship_mode_sk` INT, `cr_warehouse_sk` INT, `cr_reason_sk` INT,
+ |`cr_order_number` INT, `cr_return_quantity` INT, `cr_return_amount` DECIMAL(7,2),
+ |`cr_return_tax` DECIMAL(7,2), `cr_return_amt_inc_tax` DECIMAL(7,2), `cr_fee` DECIMAL(7,2),
+ |`cr_return_ship_cost` DECIMAL(7,2), `cr_refunded_cash` DECIMAL(7,2),
+ |`cr_reversed_charge` DECIMAL(7,2), `cr_store_credit` DECIMAL(7,2),
+ |`cr_net_loss` DECIMAL(7,2))
+ |USING parquet
+ """.stripMargin)
+
+ sql(
+ """
+ |CREATE TABLE `customer` (
+ |`c_customer_sk` INT, `c_customer_id` STRING, `c_current_cdemo_sk` INT,
+ |`c_current_hdemo_sk` INT, `c_current_addr_sk` INT, `c_first_shipto_date_sk` INT,
+ |`c_first_sales_date_sk` INT, `c_salutation` STRING, `c_first_name` STRING,
+ |`c_last_name` STRING, `c_preferred_cust_flag` STRING, `c_birth_day` INT,
+ |`c_birth_month` INT, `c_birth_year` INT, `c_birth_country` STRING, `c_login` STRING,
+ |`c_email_address` STRING, `c_last_review_date` STRING)
+ |USING parquet
+ """.stripMargin)
+
+ sql(
+ """
+ |CREATE TABLE `customer_address` (
+ |`ca_address_sk` INT, `ca_address_id` STRING, `ca_street_number` STRING,
+ |`ca_street_name` STRING, `ca_street_type` STRING, `ca_suite_number` STRING,
+ |`ca_city` STRING, `ca_county` STRING, `ca_state` STRING, `ca_zip` STRING,
+ |`ca_country` STRING, `ca_gmt_offset` DECIMAL(5,2), `ca_location_type` STRING)
+ |USING parquet
+ """.stripMargin)
+
+ sql(
+ """
+ |CREATE TABLE `customer_demographics` (
+ |`cd_demo_sk` INT, `cd_gender` STRING, `cd_marital_status` STRING,
+ |`cd_education_status` STRING, `cd_purchase_estimate` INT, `cd_credit_rating` STRING,
+ |`cd_dep_count` INT, `cd_dep_employed_count` INT, `cd_dep_college_count` INT)
+ |USING parquet
+ """.stripMargin)
+
+ sql(
+ """
+ |CREATE TABLE `date_dim` (
+ |`d_date_sk` INT, `d_date_id` STRING, `d_date` STRING,
+ |`d_month_seq` INT, `d_week_seq` INT, `d_quarter_seq` INT, `d_year` INT, `d_dow` INT,
+ |`d_moy` INT, `d_dom` INT, `d_qoy` INT, `d_fy_year` INT, `d_fy_quarter_seq` INT,
+ |`d_fy_week_seq` INT, `d_day_name` STRING, `d_quarter_name` STRING, `d_holiday` STRING,
+ |`d_weekend` STRING, `d_following_holiday` STRING, `d_first_dom` INT, `d_last_dom` INT,
+ |`d_same_day_ly` INT, `d_same_day_lq` INT, `d_current_day` STRING, `d_current_week` STRING,
+ |`d_current_month` STRING, `d_current_quarter` STRING, `d_current_year` STRING)
+ |USING parquet
+ """.stripMargin)
+
+ sql(
+ """
+ |CREATE TABLE `household_demographics` (
+ |`hd_demo_sk` INT, `hd_income_band_sk` INT, `hd_buy_potential` STRING, `hd_dep_count` INT,
+ |`hd_vehicle_count` INT)
+ |USING parquet
+ """.stripMargin)
+
+ sql(
+ """
+ |CREATE TABLE `inventory` (`inv_date_sk` INT, `inv_item_sk` INT, `inv_warehouse_sk` INT,
+ |`inv_quantity_on_hand` INT)
+ |USING parquet
+ """.stripMargin)
+
+ sql(
+ """
+ |CREATE TABLE `item` (`i_item_sk` INT, `i_item_id` STRING, `i_rec_start_date` STRING,
+ |`i_rec_end_date` STRING, `i_item_desc` STRING, `i_current_price` DECIMAL(7,2),
+ |`i_wholesale_cost` DECIMAL(7,2), `i_brand_id` INT, `i_brand` STRING, `i_class_id` INT,
+ |`i_class` STRING, `i_category_id` INT, `i_category` STRING, `i_manufact_id` INT,
+ |`i_manufact` STRING, `i_size` STRING, `i_formulation` STRING, `i_color` STRING,
+ |`i_units` STRING, `i_container` STRING, `i_manager_id` INT, `i_product_name` STRING)
+ |USING parquet
+ """.stripMargin)
+
+ sql(
+ """
+ |CREATE TABLE `promotion` (
+ |`p_promo_sk` INT, `p_promo_id` STRING, `p_start_date_sk` INT, `p_end_date_sk` INT,
+ |`p_item_sk` INT, `p_cost` DECIMAL(15,2), `p_response_target` INT, `p_promo_name` STRING,
+ |`p_channel_dmail` STRING, `p_channel_email` STRING, `p_channel_catalog` STRING,
+ |`p_channel_tv` STRING, `p_channel_radio` STRING, `p_channel_press` STRING,
+ |`p_channel_event` STRING, `p_channel_demo` STRING, `p_channel_details` STRING,
+ |`p_purpose` STRING, `p_discount_active` STRING)
+ |USING parquet
+ """.stripMargin)
+
+ sql(
+ """
+ |CREATE TABLE `store` (
+ |`s_store_sk` INT, `s_store_id` STRING, `s_rec_start_date` STRING,
+ |`s_rec_end_date` STRING, `s_closed_date_sk` INT, `s_store_name` STRING,
+ |`s_number_employees` INT, `s_floor_space` INT, `s_hours` STRING, `s_manager` STRING,
+ |`s_market_id` INT, `s_geography_class` STRING, `s_market_desc` STRING,
+ |`s_market_manager` STRING, `s_division_id` INT, `s_division_name` STRING,
+ |`s_company_id` INT, `s_company_name` STRING, `s_street_number` STRING,
+ |`s_street_name` STRING, `s_street_type` STRING, `s_suite_number` STRING, `s_city` STRING,
+ |`s_county` STRING, `s_state` STRING, `s_zip` STRING, `s_country` STRING,
+ |`s_gmt_offset` DECIMAL(5,2), `s_tax_precentage` DECIMAL(5,2))
+ |USING parquet
+ """.stripMargin)
+
+ sql(
+ """
+ |CREATE TABLE `store_returns` (
+ |`sr_returned_date_sk` BIGINT, `sr_return_time_sk` BIGINT, `sr_item_sk` BIGINT,
+ |`sr_customer_sk` BIGINT, `sr_cdemo_sk` BIGINT, `sr_hdemo_sk` BIGINT, `sr_addr_sk` BIGINT,
+ |`sr_store_sk` BIGINT, `sr_reason_sk` BIGINT, `sr_ticket_number` BIGINT,
+ |`sr_return_quantity` BIGINT, `sr_return_amt` DECIMAL(7,2), `sr_return_tax` DECIMAL(7,2),
+ |`sr_return_amt_inc_tax` DECIMAL(7,2), `sr_fee` DECIMAL(7,2),
+ |`sr_return_ship_cost` DECIMAL(7,2), `sr_refunded_cash` DECIMAL(7,2),
+ |`sr_reversed_charge` DECIMAL(7,2), `sr_store_credit` DECIMAL(7,2),
+ |`sr_net_loss` DECIMAL(7,2))
+ |USING parquet
+ """.stripMargin)
+
+ sql(
+ """
+ |CREATE TABLE `catalog_sales` (
+ |`cs_sold_date_sk` INT, `cs_sold_time_sk` INT, `cs_ship_date_sk` INT,
+ |`cs_bill_customer_sk` INT, `cs_bill_cdemo_sk` INT, `cs_bill_hdemo_sk` INT,
+ |`cs_bill_addr_sk` INT, `cs_ship_customer_sk` INT, `cs_ship_cdemo_sk` INT,
+ |`cs_ship_hdemo_sk` INT, `cs_ship_addr_sk` INT, `cs_call_center_sk` INT,
+ |`cs_catalog_page_sk` INT, `cs_ship_mode_sk` INT, `cs_warehouse_sk` INT,
+ |`cs_item_sk` INT, `cs_promo_sk` INT, `cs_order_number` INT, `cs_quantity` INT,
+ |`cs_wholesale_cost` DECIMAL(7,2), `cs_list_price` DECIMAL(7,2),
+ |`cs_sales_price` DECIMAL(7,2), `cs_ext_discount_amt` DECIMAL(7,2),
+ |`cs_ext_sales_price` DECIMAL(7,2), `cs_ext_wholesale_cost` DECIMAL(7,2),
+ |`cs_ext_list_price` DECIMAL(7,2), `cs_ext_tax` DECIMAL(7,2), `cs_coupon_amt` DECIMAL(7,2),
+ |`cs_ext_ship_cost` DECIMAL(7,2), `cs_net_paid` DECIMAL(7,2),
+ |`cs_net_paid_inc_tax` DECIMAL(7,2), `cs_net_paid_inc_ship` DECIMAL(7,2),
+ |`cs_net_paid_inc_ship_tax` DECIMAL(7,2), `cs_net_profit` DECIMAL(7,2))
+ |USING parquet
+ """.stripMargin)
+
+ sql(
+ """
+ |CREATE TABLE `web_sales` (
+ |`ws_sold_date_sk` INT, `ws_sold_time_sk` INT, `ws_ship_date_sk` INT, `ws_item_sk` INT,
+ |`ws_bill_customer_sk` INT, `ws_bill_cdemo_sk` INT, `ws_bill_hdemo_sk` INT,
+ |`ws_bill_addr_sk` INT, `ws_ship_customer_sk` INT, `ws_ship_cdemo_sk` INT,
+ |`ws_ship_hdemo_sk` INT, `ws_ship_addr_sk` INT, `ws_web_page_sk` INT, `ws_web_site_sk` INT,
+ |`ws_ship_mode_sk` INT, `ws_warehouse_sk` INT, `ws_promo_sk` INT, `ws_order_number` INT,
+ |`ws_quantity` INT, `ws_wholesale_cost` DECIMAL(7,2), `ws_list_price` DECIMAL(7,2),
+ |`ws_sales_price` DECIMAL(7,2), `ws_ext_discount_amt` DECIMAL(7,2),
+ |`ws_ext_sales_price` DECIMAL(7,2), `ws_ext_wholesale_cost` DECIMAL(7,2),
+ |`ws_ext_list_price` DECIMAL(7,2), `ws_ext_tax` DECIMAL(7,2),
+ |`ws_coupon_amt` DECIMAL(7,2), `ws_ext_ship_cost` DECIMAL(7,2), `ws_net_paid` DECIMAL(7,2),
+ |`ws_net_paid_inc_tax` DECIMAL(7,2), `ws_net_paid_inc_ship` DECIMAL(7,2),
+ |`ws_net_paid_inc_ship_tax` DECIMAL(7,2), `ws_net_profit` DECIMAL(7,2))
+ |USING parquet
+ """.stripMargin)
+
+ sql(
+ """
+ |CREATE TABLE `store_sales` (
+ |`ss_sold_date_sk` INT, `ss_sold_time_sk` INT, `ss_item_sk` INT, `ss_customer_sk` INT,
+ |`ss_cdemo_sk` INT, `ss_hdemo_sk` INT, `ss_addr_sk` INT, `ss_store_sk` INT,
+ |`ss_promo_sk` INT, `ss_ticket_number` INT, `ss_quantity` INT,
+ |`ss_wholesale_cost` DECIMAL(7,2), `ss_list_price` DECIMAL(7,2),
+ |`ss_sales_price` DECIMAL(7,2), `ss_ext_discount_amt` DECIMAL(7,2),
+ |`ss_ext_sales_price` DECIMAL(7,2), `ss_ext_wholesale_cost` DECIMAL(7,2),
+ |`ss_ext_list_price` DECIMAL(7,2), `ss_ext_tax` DECIMAL(7,2),
+ |`ss_coupon_amt` DECIMAL(7,2), `ss_net_paid` DECIMAL(7,2),
+ |`ss_net_paid_inc_tax` DECIMAL(7,2), `ss_net_profit` DECIMAL(7,2))
+ |USING parquet
+ """.stripMargin)
+
+ sql(
+ """
+ |CREATE TABLE `web_returns` (
+ |`wr_returned_date_sk` BIGINT, `wr_returned_time_sk` BIGINT, `wr_item_sk` BIGINT,
+ |`wr_refunded_customer_sk` BIGINT, `wr_refunded_cdemo_sk` BIGINT,
+ |`wr_refunded_hdemo_sk` BIGINT, `wr_refunded_addr_sk` BIGINT,
+ |`wr_returning_customer_sk` BIGINT, `wr_returning_cdemo_sk` BIGINT,
+ |`wr_returning_hdemo_sk` BIGINT, `wr_returning_addr_sk` BIGINT, `wr_web_page_sk` BIGINT,
+ |`wr_reason_sk` BIGINT, `wr_order_number` BIGINT, `wr_return_quantity` BIGINT,
+ |`wr_return_amt` DECIMAL(7,2), `wr_return_tax` DECIMAL(7,2),
+ |`wr_return_amt_inc_tax` DECIMAL(7,2), `wr_fee` DECIMAL(7,2),
+ |`wr_return_ship_cost` DECIMAL(7,2), `wr_refunded_cash` DECIMAL(7,2),
+ |`wr_reversed_charge` DECIMAL(7,2), `wr_account_credit` DECIMAL(7,2),
+ |`wr_net_loss` DECIMAL(7,2))
+ |USING parquet
+ """.stripMargin)
+
+ sql(
+ """
+ |CREATE TABLE `web_site` (
+ |`web_site_sk` INT, `web_site_id` STRING, `web_rec_start_date` DATE,
+ |`web_rec_end_date` DATE, `web_name` STRING, `web_open_date_sk` INT,
+ |`web_close_date_sk` INT, `web_class` STRING, `web_manager` STRING, `web_mkt_id` INT,
+ |`web_mkt_class` STRING, `web_mkt_desc` STRING, `web_market_manager` STRING,
+ |`web_company_id` INT, `web_company_name` STRING, `web_street_number` STRING,
+ |`web_street_name` STRING, `web_street_type` STRING, `web_suite_number` STRING,
+ |`web_city` STRING, `web_county` STRING, `web_state` STRING, `web_zip` STRING,
+ |`web_country` STRING, `web_gmt_offset` STRING, `web_tax_percentage` DECIMAL(5,2))
+ |USING parquet
+ """.stripMargin)
+
+ sql(
+ """
+ |CREATE TABLE `reason` (
+ |`r_reason_sk` INT, `r_reason_id` STRING, `r_reason_desc` STRING)
+ |USING parquet
+ """.stripMargin)
+
+ sql(
+ """
+ |CREATE TABLE `call_center` (
+ |`cc_call_center_sk` INT, `cc_call_center_id` STRING, `cc_rec_start_date` DATE,
+ |`cc_rec_end_date` DATE, `cc_closed_date_sk` INT, `cc_open_date_sk` INT, `cc_name` STRING,
+ |`cc_class` STRING, `cc_employees` INT, `cc_sq_ft` INT, `cc_hours` STRING,
+ |`cc_manager` STRING, `cc_mkt_id` INT, `cc_mkt_class` STRING, `cc_mkt_desc` STRING,
+ |`cc_market_manager` STRING, `cc_division` INT, `cc_division_name` STRING, `cc_company` INT,
+ |`cc_company_name` STRING, `cc_street_number` STRING, `cc_street_name` STRING,
+ |`cc_street_type` STRING, `cc_suite_number` STRING, `cc_city` STRING, `cc_county` STRING,
+ |`cc_state` STRING, `cc_zip` STRING, `cc_country` STRING, `cc_gmt_offset` DECIMAL(5,2),
+ |`cc_tax_percentage` DECIMAL(5,2))
+ |USING parquet
+ """.stripMargin)
+
+ sql(
+ """
+ |CREATE TABLE `warehouse` (
+ |`w_warehouse_sk` INT, `w_warehouse_id` STRING, `w_warehouse_name` STRING,
+ |`w_warehouse_sq_ft` INT, `w_street_number` STRING, `w_street_name` STRING,
+ |`w_street_type` STRING, `w_suite_number` STRING, `w_city` STRING, `w_county` STRING,
+ |`w_state` STRING, `w_zip` STRING, `w_country` STRING, `w_gmt_offset` DECIMAL(5,2))
+ |USING parquet
+ """.stripMargin)
+
+ sql(
+ """
+ |CREATE TABLE `ship_mode` (
+ |`sm_ship_mode_sk` INT, `sm_ship_mode_id` STRING, `sm_type` STRING, `sm_code` STRING,
+ |`sm_carrier` STRING, `sm_contract` STRING)
+ |USING parquet
+ """.stripMargin)
+
+ sql(
+ """
+ |CREATE TABLE `income_band` (
+ |`ib_income_band_sk` INT, `ib_lower_bound` INT, `ib_upper_bound` INT)
+ |USING parquet
+ """.stripMargin)
+
+ sql(
+ """
+ |CREATE TABLE `time_dim` (
+ |`t_time_sk` INT, `t_time_id` STRING, `t_time` INT, `t_hour` INT, `t_minute` INT,
+ |`t_second` INT, `t_am_pm` STRING, `t_shift` STRING, `t_sub_shift` STRING,
+ |`t_meal_time` STRING)
+ |USING parquet
+ """.stripMargin)
+
+ sql(
+ """
+ |CREATE TABLE `web_page` (`wp_web_page_sk` INT, `wp_web_page_id` STRING,
+ |`wp_rec_start_date` DATE, `wp_rec_end_date` DATE, `wp_creation_date_sk` INT,
+ |`wp_access_date_sk` INT, `wp_autogen_flag` STRING, `wp_customer_sk` INT,
+ |`wp_url` STRING, `wp_type` STRING, `wp_char_count` INT, `wp_link_count` INT,
+ |`wp_image_count` INT, `wp_max_ad_count` INT)
+ |USING parquet
+ """.stripMargin)
+ }
+
+ val tpcdsQueries = Seq(
+ "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11",
+ "q12", "q13", "q14a", "q14b", "q15", "q16", "q17", "q18", "q19", "q20",
+ "q21", "q22", "q23a", "q23b", "q24a", "q24b", "q25", "q26", "q27", "q28", "q29", "q30",
+ "q31", "q32", "q33", "q34", "q35", "q36", "q37", "q38", "q39a", "q39b", "q40",
+ "q41", "q42", "q43", "q44", "q45", "q46", "q47", "q48", "q49", "q50",
+ "q51", "q52", "q53", "q54", "q55", "q56", "q57", "q58", "q59", "q60",
+ "q61", "q62", "q63", "q64", "q65", "q66", "q67", "q68", "q69", "q70",
+ "q71", "q72", "q73", "q74", "q75", "q76", "q77", "q78", "q79", "q80",
+ "q81", "q82", "q83", "q84", "q85", "q86", "q87", "q88", "q89", "q90",
+ "q91", "q92", "q93", "q94", "q95", "q96", "q97", "q98", "q99")
+
+ tpcdsQueries.foreach { name =>
+ val queryString = resourceToString(s"tpcds/$name.sql",
+ classLoader = Thread.currentThread().getContextClassLoader)
+ test(name) {
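+      // Cross joins are disabled by default, and some TPC-DS queries require them.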
+ withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+ // Just check the plans can be properly generated
+ sql(queryString).queryExecution.executedPlan
+ }
+ }
+ }
+
+ // These queries are from https://github.com/cloudera/impala-tpcds-kit/tree/master/queries
+ val modifiedTPCDSQueries = Seq(
+ "q3", "q7", "q10", "q19", "q27", "q34", "q42", "q43", "q46", "q52", "q53", "q55", "q59",
+ "q63", "q65", "q68", "q73", "q79", "q89", "q98", "ss_max")
+
+ modifiedTPCDSQueries.foreach { name =>
+ val queryString = resourceToString(s"tpcds-modifiedQueries/$name.sql",
+ classLoader = Thread.currentThread().getContextClassLoader)
+ test(s"modified-$name") {
+ // Just check the plans can be properly generated
+ sql(queryString).queryExecution.executedPlan
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index ae6b2bc3753fb..6f8723af91cea 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -93,6 +93,13 @@ class UDFSuite extends QueryTest with SharedSQLContext {
assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4)
}
+ test("UDF defined using UserDefinedFunction") {
+ import functions.udf
+ val foo = udf((x: Int) => x + 1)
+ spark.udf.register("foo", foo)
+ assert(sql("select foo(5)").head().getInt(0) == 6)
+ }
+
test("ZeroArgument UDF") {
spark.udf.register("random0", () => { Math.random()})
assert(sql("SELECT random0()").head().getDouble(0) >= 0.0)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
index a32763db054f3..a5f904c621e6e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
@@ -101,9 +101,22 @@ class UnsafeRowSuite extends SparkFunSuite {
MemoryAllocator.UNSAFE.free(offheapRowPage)
}
}
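+  // Also serialize a copy of the row whose backing byte array has a non-zero base offset (copied in at offset 100).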
+ val (bytesFromArrayBackedRowWithOffset, field0StringFromArrayBackedRowWithOffset) = {
+ val baos = new ByteArrayOutputStream()
+ val numBytes = arrayBackedUnsafeRow.getSizeInBytes
+ val bytesWithOffset = new Array[Byte](numBytes + 100)
+ System.arraycopy(arrayBackedUnsafeRow.getBaseObject.asInstanceOf[Array[Byte]], 0,
+ bytesWithOffset, 100, numBytes)
+ val arrayBackedRow = new UnsafeRow(arrayBackedUnsafeRow.numFields())
+ arrayBackedRow.pointTo(bytesWithOffset, Platform.BYTE_ARRAY_OFFSET + 100, numBytes)
+ arrayBackedRow.writeToStream(baos, null)
+ (baos.toByteArray, arrayBackedRow.getString(0))
+ }
assert(bytesFromArrayBackedRow === bytesFromOffheapRow)
assert(field0StringFromArrayBackedRow === field0StringFromOffheapRow)
+ assert(bytesFromArrayBackedRow === bytesFromArrayBackedRowWithOffset)
+ assert(field0StringFromArrayBackedRow === field0StringFromArrayBackedRowWithOffset)
}
test("calling getDouble() and getFloat() on null columns") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala
index 05a2b2c862c73..423e1288e8dcb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala
@@ -18,22 +18,17 @@ package org.apache.spark.sql.execution
import org.apache.hadoop.fs.Path
+import org.apache.spark.SparkConf
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.util.Utils
/**
* Suite that tests the redaction of DataSourceScanExec
*/
class DataSourceScanExecRedactionSuite extends QueryTest with SharedSQLContext {
- import Utils._
-
- override def beforeAll(): Unit = {
- sparkConf.set("spark.redaction.string.regex",
- "file:/[\\w_]+")
- super.beforeAll()
- }
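+  // Provide the redaction regex through the suite's SparkConf instead of mutating it in beforeAll.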
+ override protected def sparkConf: SparkConf = super.sparkConf
+ .set("spark.redaction.string.regex", "file:/[\\w_]+")
test("treeString is redacted") {
withTempDir { dir =>
@@ -43,7 +38,7 @@ class DataSourceScanExecRedactionSuite extends QueryTest with SharedSQLContext {
val rootPath = df.queryExecution.sparkPlan.find(_.isInstanceOf[FileSourceScanExec]).get
.asInstanceOf[FileSourceScanExec].relation.location.rootPaths.head
- assert(rootPath.toString.contains(basePath.toString))
+ assert(rootPath.toString.contains(dir.toURI.getPath.stripSuffix("/")))
assert(!df.queryExecution.sparkPlan.treeString(verbose = true).contains(rootPath.getName))
assert(!df.queryExecution.executedPlan.treeString(verbose = true).contains(rootPath.getName))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala
index 00c5f2550cbb1..a5adc3639ad64 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala
@@ -67,7 +67,10 @@ object ExternalAppendOnlyUnsafeRowArrayBenchmark {
benchmark.addCase("ExternalAppendOnlyUnsafeRowArray") { _: Int =>
var sum = 0L
for (_ <- 0L until iterations) {
- val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold)
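+        // The array now also takes an initial in-memory buffer size in addition to the spill threshold.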
+ val array = new ExternalAppendOnlyUnsafeRowArray(
+ ExternalAppendOnlyUnsafeRowArray.DefaultInitialSizeOfInMemoryBuffer,
+ numSpillThreshold)
+
rows.foreach(x => array.add(x))
val iterator = array.generateIterator()
@@ -143,7 +146,7 @@ object ExternalAppendOnlyUnsafeRowArrayBenchmark {
benchmark.addCase("ExternalAppendOnlyUnsafeRowArray") { _: Int =>
var sum = 0L
for (_ <- 0L until iterations) {
- val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold)
+ val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold, numSpillThreshold)
rows.foreach(x => array.add(x))
val iterator = array.generateIterator()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala
index 53c41639942b4..ecc7264d79442 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala
@@ -31,7 +31,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar
override def afterAll(): Unit = TaskContext.unset()
- private def withExternalArray(spillThreshold: Int)
+ private def withExternalArray(inMemoryThreshold: Int, spillThreshold: Int)
(f: ExternalAppendOnlyUnsafeRowArray => Unit): Unit = {
sc = new SparkContext("local", "test", new SparkConf(false))
@@ -45,6 +45,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar
taskContext,
1024,
SparkEnv.get.memoryManager.pageSizeBytes,
+ inMemoryThreshold,
spillThreshold)
try f(array) finally {
array.clear()
@@ -109,9 +110,9 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar
assert(getNumBytesSpilled > 0)
}
- test("insert rows less than the spillThreshold") {
- val spillThreshold = 100
- withExternalArray(spillThreshold) { array =>
+ test("insert rows less than the inMemoryThreshold") {
+ val (inMemoryThreshold, spillThreshold) = (100, 50)
+ withExternalArray(inMemoryThreshold, spillThreshold) { array =>
assert(array.isEmpty)
val expectedValues = populateRows(array, 1)
@@ -122,8 +123,8 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar
// Add more rows (but not too many to trigger switch to [[UnsafeExternalSorter]])
// Verify that NO spill has happened
- populateRows(array, spillThreshold - 1, expectedValues)
- assert(array.length == spillThreshold)
+ populateRows(array, inMemoryThreshold - 1, expectedValues)
+ assert(array.length == inMemoryThreshold)
assertNoSpill()
val iterator2 = validateData(array, expectedValues)
@@ -133,20 +134,42 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar
}
}
- test("insert rows more than the spillThreshold to force spill") {
- val spillThreshold = 100
- withExternalArray(spillThreshold) { array =>
- val numValuesInserted = 20 * spillThreshold
-
+ test("insert rows more than the inMemoryThreshold but less than spillThreshold") {
+ val (inMemoryThreshold, spillThreshold) = (10, 50)
+ withExternalArray(inMemoryThreshold, spillThreshold) { array =>
assert(array.isEmpty)
- val expectedValues = populateRows(array, 1)
- assert(array.length == 1)
+ val expectedValues = populateRows(array, inMemoryThreshold - 1)
+ assert(array.length == (inMemoryThreshold - 1))
+ val iterator1 = validateData(array, expectedValues)
+ assertNoSpill()
+
+ // Add more rows to trigger switch to [[UnsafeExternalSorter]] but not too many to cause a
+ // spill to happen. Verify that NO spill has happened
+ populateRows(array, spillThreshold - expectedValues.length - 1, expectedValues)
+ assert(array.length == spillThreshold - 1)
+ assertNoSpill()
+
+ val iterator2 = validateData(array, expectedValues)
+ assert(!iterator2.hasNext)
+ assert(!iterator1.hasNext)
+ intercept[ConcurrentModificationException](iterator1.next())
+ }
+ }
+
+ test("insert rows enough to force spill") {
+ val (inMemoryThreshold, spillThreshold) = (20, 10)
+ withExternalArray(inMemoryThreshold, spillThreshold) { array =>
+ assert(array.isEmpty)
+ val expectedValues = populateRows(array, inMemoryThreshold - 1)
+ assert(array.length == (inMemoryThreshold - 1))
val iterator1 = validateData(array, expectedValues)
+ assertNoSpill()
- // Populate more rows to trigger spill. Verify that spill has happened
- populateRows(array, numValuesInserted - 1, expectedValues)
- assert(array.length == numValuesInserted)
+ // Add more rows to trigger switch to [[UnsafeExternalSorter]] and cause a spill to happen.
+ // Verify that spill has happened
+ populateRows(array, 2, expectedValues)
+ assert(array.length == inMemoryThreshold + 1)
assertSpill()
val iterator2 = validateData(array, expectedValues)
@@ -158,7 +181,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar
}
test("iterator on an empty array should be empty") {
- withExternalArray(spillThreshold = 10) { array =>
+ withExternalArray(inMemoryThreshold = 4, spillThreshold = 10) { array =>
val iterator = array.generateIterator()
assert(array.isEmpty)
assert(array.length == 0)
@@ -167,7 +190,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar
}
test("generate iterator with negative start index") {
- withExternalArray(spillThreshold = 2) { array =>
+ withExternalArray(inMemoryThreshold = 100, spillThreshold = 56) { array =>
val exception =
intercept[ArrayIndexOutOfBoundsException](array.generateIterator(startIndex = -10))
@@ -178,8 +201,8 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar
}
test("generate iterator with start index exceeding array's size (without spill)") {
- val spillThreshold = 2
- withExternalArray(spillThreshold) { array =>
+ val (inMemoryThreshold, spillThreshold) = (20, 100)
+ withExternalArray(inMemoryThreshold, spillThreshold) { array =>
populateRows(array, spillThreshold / 2)
val exception =
@@ -191,8 +214,8 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar
}
test("generate iterator with start index exceeding array's size (with spill)") {
- val spillThreshold = 2
- withExternalArray(spillThreshold) { array =>
+ val (inMemoryThreshold, spillThreshold) = (20, 100)
+ withExternalArray(inMemoryThreshold, spillThreshold) { array =>
populateRows(array, spillThreshold * 2)
val exception =
@@ -205,10 +228,10 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar
}
test("generate iterator with custom start index (without spill)") {
- val spillThreshold = 10
- withExternalArray(spillThreshold) { array =>
- val expectedValues = populateRows(array, spillThreshold)
- val startIndex = spillThreshold / 2
+ val (inMemoryThreshold, spillThreshold) = (20, 100)
+ withExternalArray(inMemoryThreshold, spillThreshold) { array =>
+ val expectedValues = populateRows(array, inMemoryThreshold)
+ val startIndex = inMemoryThreshold / 2
val iterator = array.generateIterator(startIndex = startIndex)
for (i <- startIndex until expectedValues.length) {
checkIfValueExists(iterator, expectedValues(i))
@@ -217,8 +240,8 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar
}
test("generate iterator with custom start index (with spill)") {
- val spillThreshold = 10
- withExternalArray(spillThreshold) { array =>
+ val (inMemoryThreshold, spillThreshold) = (20, 100)
+ withExternalArray(inMemoryThreshold, spillThreshold) { array =>
val expectedValues = populateRows(array, spillThreshold * 10)
val startIndex = spillThreshold * 2
val iterator = array.generateIterator(startIndex = startIndex)
@@ -229,7 +252,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar
}
test("test iterator invalidation (without spill)") {
- withExternalArray(spillThreshold = 10) { array =>
+ withExternalArray(inMemoryThreshold = 10, spillThreshold = 100) { array =>
// insert 2 rows, iterate until the first row
populateRows(array, 2)
@@ -254,9 +277,9 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar
}
test("test iterator invalidation (with spill)") {
- val spillThreshold = 10
- withExternalArray(spillThreshold) { array =>
- // Populate enough rows so that spill has happens
+ val (inMemoryThreshold, spillThreshold) = (2, 10)
+ withExternalArray(inMemoryThreshold, spillThreshold) { array =>
+ // Populate enough rows so that spill happens
populateRows(array, spillThreshold * 2)
assertSpill()
@@ -281,7 +304,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar
}
test("clear on an empty the array") {
- withExternalArray(spillThreshold = 2) { array =>
+ withExternalArray(inMemoryThreshold = 2, spillThreshold = 3) { array =>
val iterator = array.generateIterator()
assert(!iterator.hasNext)
@@ -299,10 +322,10 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar
}
test("clear array (without spill)") {
- val spillThreshold = 10
- withExternalArray(spillThreshold) { array =>
+ val (inMemoryThreshold, spillThreshold) = (10, 100)
+ withExternalArray(inMemoryThreshold, spillThreshold) { array =>
// Populate rows ... but not enough to trigger spill
- populateRows(array, spillThreshold / 2)
+ populateRows(array, inMemoryThreshold / 2)
assertNoSpill()
// Clear the array
@@ -311,21 +334,21 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar
// Re-populate few rows so that there is no spill
// Verify the data. Verify that there was no spill
- val expectedValues = populateRows(array, spillThreshold / 3)
+ val expectedValues = populateRows(array, inMemoryThreshold / 2)
validateData(array, expectedValues)
assertNoSpill()
// Populate more rows .. enough to not trigger a spill.
// Verify the data. Verify that there was no spill
- populateRows(array, spillThreshold / 3, expectedValues)
+ populateRows(array, inMemoryThreshold / 2, expectedValues)
validateData(array, expectedValues)
assertNoSpill()
}
}
test("clear array (with spill)") {
- val spillThreshold = 10
- withExternalArray(spillThreshold) { array =>
+ val (inMemoryThreshold, spillThreshold) = (10, 20)
+ withExternalArray(inMemoryThreshold, spillThreshold) { array =>
// Populate enough rows to trigger spill
populateRows(array, spillThreshold * 2)
val bytesSpilled = getNumBytesSpilled
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala
index 5c63c6a414f93..cc943e0356f2a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala
@@ -35,39 +35,53 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext {
private var globalTempDB: String = _
test("basic semantic") {
- sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1, 'a'")
-
- // If there is no database in table name, we should try local temp view first, if not found,
- // try table/view in current database, which is "default" in this case. So we expect
- // NoSuchTableException here.
- intercept[NoSuchTableException](spark.table("src"))
-
- // Use qualified name to refer to the global temp view explicitly.
- checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, "a"))
-
- // Table name without database will never refer to a global temp view.
- intercept[NoSuchTableException](sql("DROP VIEW src"))
-
- sql(s"DROP VIEW $globalTempDB.src")
- // The global temp view should be dropped successfully.
- intercept[NoSuchTableException](spark.table(s"$globalTempDB.src"))
-
- // We can also use Dataset API to create global temp view
- Seq(1 -> "a").toDF("i", "j").createGlobalTempView("src")
- checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, "a"))
-
- // Use qualified name to rename a global temp view.
- sql(s"ALTER VIEW $globalTempDB.src RENAME TO src2")
- intercept[NoSuchTableException](spark.table(s"$globalTempDB.src"))
- checkAnswer(spark.table(s"$globalTempDB.src2"), Row(1, "a"))
-
- // Use qualified name to alter a global temp view.
- sql(s"ALTER VIEW $globalTempDB.src2 AS SELECT 2, 'b'")
- checkAnswer(spark.table(s"$globalTempDB.src2"), Row(2, "b"))
-
- // We can also use Catalog API to drop global temp view
- spark.catalog.dropGlobalTempView("src2")
- intercept[NoSuchTableException](spark.table(s"$globalTempDB.src2"))
+ val expectedErrorMsg = "not found"
+ try {
+ sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1, 'a'")
+
+ // If there is no database in table name, we should try local temp view first, if not found,
+ // try table/view in current database, which is "default" in this case. So we expect
+ // NoSuchTableException here.
+ var e = intercept[AnalysisException](spark.table("src")).getMessage
+ assert(e.contains(expectedErrorMsg))
+
+ // Use qualified name to refer to the global temp view explicitly.
+ checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, "a"))
+
+ // Table name without database will never refer to a global temp view.
+ e = intercept[AnalysisException](sql("DROP VIEW src")).getMessage
+ assert(e.contains(expectedErrorMsg))
+
+ sql(s"DROP VIEW $globalTempDB.src")
+ // The global temp view should be dropped successfully.
+ e = intercept[AnalysisException](spark.table(s"$globalTempDB.src")).getMessage
+ assert(e.contains(expectedErrorMsg))
+
+ // We can also use Dataset API to create global temp view
+ Seq(1 -> "a").toDF("i", "j").createGlobalTempView("src")
+ checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, "a"))
+
+ // Use qualified name to rename a global temp view.
+ sql(s"ALTER VIEW $globalTempDB.src RENAME TO src2")
+ e = intercept[AnalysisException](spark.table(s"$globalTempDB.src")).getMessage
+ assert(e.contains(expectedErrorMsg))
+ checkAnswer(spark.table(s"$globalTempDB.src2"), Row(1, "a"))
+
+ // Use qualified name to alter a global temp view.
+ sql(s"ALTER VIEW $globalTempDB.src2 AS SELECT 2, 'b'")
+ checkAnswer(spark.table(s"$globalTempDB.src2"), Row(2, "b"))
+
+ // We can also use Catalog API to drop global temp view
+ spark.catalog.dropGlobalTempView("src2")
+ e = intercept[AnalysisException](spark.table(s"$globalTempDB.src2")).getMessage
+ assert(e.contains(expectedErrorMsg))
+
+ // We can also use Dataset API to replace global temp view
+ Seq(2 -> "b").toDF("i", "j").createOrReplaceGlobalTempView("src")
+ checkAnswer(spark.table(s"$globalTempDB.src"), Row(2, "b"))
+ } finally {
+ spark.catalog.dropGlobalTempView("src")
+ }
}
test("global temp view is shared among all sessions") {
@@ -106,7 +120,7 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext {
test("CREATE TABLE LIKE should work for global temp view") {
try {
sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1 AS a, '2' AS b")
- sql(s"CREATE TABLE cloned LIKE ${globalTempDB}.src")
+ sql(s"CREATE TABLE cloned LIKE $globalTempDB.src")
val tableMeta = spark.sessionState.catalog.getTableMetadata(TableIdentifier("cloned"))
assert(tableMeta.schema == new StructType().add("a", "int", false).add("b", "string", false))
} finally {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala
index 58c310596ca6d..6c66902127d03 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala
@@ -117,4 +117,12 @@ class OptimizeMetadataOnlyQuerySuite extends QueryTest with SharedSQLContext {
"select partcol1, max(partcol2) from srcpart where partcol1 = 0 group by rollup (partcol1)",
"select partcol2 from (select partcol2 from srcpart where partcol1 = 0 union all " +
"select partcol2 from srcpart where partcol1 = 1) t group by partcol2")
+
+ test("SPARK-21884 Fix StackOverflowError on MetadataOnlyQuery") {
+ withTable("t_1000") {
+ sql("CREATE TABLE t_1000 (a INT, p INT) USING PARQUET PARTITIONED BY (p)")
+ (1 to 1000).foreach(p => sql(s"ALTER TABLE t_1000 ADD PARTITION (p=$p)"))
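+      // Creating this many partitions reproduces the StackOverflowError fixed by SPARK-21884;
+      // the query only needs to complete without throwing.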
+ sql("SELECT COUNT(DISTINCT p) FROM t_1000").collect()
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 4d155d538d637..63e17c7f372b0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{execution, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, RightOuter}
+import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.columnar.InMemoryRelation
@@ -513,26 +513,30 @@ class PlannerSuite extends SharedSQLContext {
}
test("EnsureRequirements skips sort when either side of join keys is required after inner SMJ") {
- val innerSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, planA, planB)
- // Both left and right keys should be sorted after the SMJ.
- Seq(orderingA, orderingB).foreach { ordering =>
- assertSortRequirementsAreSatisfied(
- childPlan = innerSmj,
- requiredOrdering = Seq(ordering),
- shouldHaveSort = false)
+ Seq(Inner, Cross).foreach { joinType =>
+ val innerSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, joinType, None, planA, planB)
+ // Both left and right keys should be sorted after the SMJ.
+ Seq(orderingA, orderingB).foreach { ordering =>
+ assertSortRequirementsAreSatisfied(
+ childPlan = innerSmj,
+ requiredOrdering = Seq(ordering),
+ shouldHaveSort = false)
+ }
}
}
test("EnsureRequirements skips sort when key order of a parent SMJ is propagated from its " +
"child SMJ") {
- val childSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, planA, planB)
- val parentSmj = SortMergeJoinExec(exprB :: Nil, exprC :: Nil, Inner, None, childSmj, planC)
- // After the second SMJ, exprA, exprB and exprC should all be sorted.
- Seq(orderingA, orderingB, orderingC).foreach { ordering =>
- assertSortRequirementsAreSatisfied(
- childPlan = parentSmj,
- requiredOrdering = Seq(ordering),
- shouldHaveSort = false)
+ Seq(Inner, Cross).foreach { joinType =>
+ val childSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, joinType, None, planA, planB)
+ val parentSmj = SortMergeJoinExec(exprB :: Nil, exprC :: Nil, joinType, None, childSmj, planC)
+ // After the second SMJ, exprA, exprB and exprC should all be sorted.
+ Seq(orderingA, orderingB, orderingC).foreach { ordering =>
+ assertSortRequirementsAreSatisfied(
+ childPlan = parentSmj,
+ requiredOrdering = Seq(ordering),
+ shouldHaveSort = false)
+ }
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala
index 1c1931b6a6daf..afccbe5cc6d19 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala
@@ -16,37 +16,36 @@
*/
package org.apache.spark.sql.execution
-import java.util.Locale
-
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation}
import org.apache.spark.sql.test.SharedSQLContext
class QueryExecutionSuite extends SharedSQLContext {
test("toString() exception/error handling") {
- val badRule = new SparkStrategy {
- var mode: String = ""
- override def apply(plan: LogicalPlan): Seq[SparkPlan] =
- mode.toLowerCase(Locale.ROOT) match {
- case "exception" => throw new AnalysisException(mode)
- case "error" => throw new Error(mode)
- case _ => Nil
- }
- }
- spark.experimental.extraStrategies = badRule :: Nil
+ spark.experimental.extraStrategies = Seq(
+ new SparkStrategy {
+ override def apply(plan: LogicalPlan): Seq[SparkPlan] = Nil
+ })
def qe: QueryExecution = new QueryExecution(spark, OneRowRelation)
// Nothing!
- badRule.mode = ""
assert(qe.toString.contains("OneRowRelation"))
// Throw an AnalysisException - this should be captured.
- badRule.mode = "exception"
+ spark.experimental.extraStrategies = Seq(
+ new SparkStrategy {
+ override def apply(plan: LogicalPlan): Seq[SparkPlan] =
+ throw new AnalysisException("exception")
+ })
assert(qe.toString.contains("org.apache.spark.sql.AnalysisException"))
// Throw an Error - this should not be captured.
- badRule.mode = "error"
+ spark.experimental.extraStrategies = Seq(
+ new SparkStrategy {
+ override def apply(plan: LogicalPlan): Seq[SparkPlan] =
+ throw new Error("error")
+ })
val error = intercept[Error](qe.toString)
assert(error.getMessage.contains("error"))
}
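The rewrite above swaps the single mutable strategy for a fresh anonymous strategy per scenario, so each case of the test stands on its own. A minimal sketch of registering a custom planner strategy through spark.experimental.extraStrategies, assuming a local SparkSession (the object name is illustrative):

import org.apache.spark.sql.{SparkSession, Strategy}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan

object ExtraStrategySketch {
  // A strategy that matches nothing and simply defers to the built-in planner.
  object NoOpStrategy extends Strategy {
    override def apply(plan: LogicalPlan): Seq[SparkPlan] = Nil
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("extra-strategies").getOrCreate()
    // Registered strategies affect every query planned by this session, which is
    // why the test replaces the strategy between scenarios instead of mutating it.
    spark.experimental.extraStrategies = Seq(NoOpStrategy)
    spark.range(3).collect().foreach(println)
    spark.stop()
  }
}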
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala
index d32716c18ddfb..08a4a21b20f61 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala
@@ -669,4 +669,29 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils {
"positive."))
}
}
+
+ test("permanent view should be case-preserving") {
+ withView("v") {
+ sql("CREATE VIEW v AS SELECT 1 as aBc")
+ assert(spark.table("v").schema.head.name == "aBc")
+
+ sql("CREATE OR REPLACE VIEW v AS SELECT 2 as cBa")
+ assert(spark.table("v").schema.head.name == "cBa")
+ }
+ }
+
+ test("sparkSession API view resolution with different default database") {
+ withDatabase("db2") {
+ withView("v1") {
+ withTable("t1") {
+ sql("USE default")
+ sql("CREATE TABLE t1 USING parquet AS SELECT 1 AS c0")
+ sql("CREATE VIEW v1 AS SELECT * FROM t1")
+ sql("CREATE DATABASE IF NOT EXISTS db2")
+ sql("USE db2")
+ checkAnswer(spark.table("default.v1"), Row(1))
+ }
+ }
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
index 52e4f047225de..a57514c256b90 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
@@ -356,6 +356,46 @@ class SQLWindowFunctionSuite extends QueryTest with SharedSQLContext {
spark.catalog.dropTempView("nums")
}
+ test("window function: mutiple window expressions specified by range in a single expression") {
+ val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y")
+ nums.createOrReplaceTempView("nums")
+ withTempView("nums") {
+ val expected =
+ Row(1, 1, 1, 4, null, 8, 25) ::
+ Row(1, 3, 4, 9, 1, 12, 24) ::
+ Row(1, 5, 9, 15, 4, 16, 21) ::
+ Row(1, 7, 16, 21, 8, 9, 16) ::
+ Row(1, 9, 25, 16, 12, null, 9) ::
+ Row(0, 2, 2, 6, null, 10, 30) ::
+ Row(0, 4, 6, 12, 2, 14, 28) ::
+ Row(0, 6, 12, 18, 6, 18, 24) ::
+ Row(0, 8, 20, 24, 10, 10, 18) ::
+ Row(0, 10, 30, 18, 14, null, 10) ::
+ Nil
+
+ val actual = sql(
+ """
+ |SELECT
+ | y,
+ | x,
+ | sum(x) over w1 as history_sum,
+ | sum(x) over w2 as period_sum1,
+ | sum(x) over w3 as period_sum2,
+ | sum(x) over w4 as period_sum3,
+ | sum(x) over w5 as future_sum
+ |FROM nums
+ |WINDOW
+ | w1 AS (PARTITION BY y ORDER BY x RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW),
+ | w2 AS (PARTITION BY y ORDER BY x RANGE BETWEEN 2 PRECEDING AND 2 FOLLOWING),
+ | w3 AS (PARTITION BY y ORDER BY x RANGE BETWEEN 4 PRECEDING AND 2 PRECEDING),
+ | w4 AS (PARTITION BY y ORDER BY x RANGE BETWEEN 2 FOLLOWING AND 4 FOLLOWING),
+ | w5 AS (PARTITION BY y ORDER BY x RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)
+ """.stripMargin
+ )
+ checkAnswer(actual, expected)
+ }
+ }
+
test("SPARK-7595: Window will cause resolve failed with self join") {
checkAnswer(sql(
"""
@@ -437,7 +477,8 @@ class SQLWindowFunctionSuite extends QueryTest with SharedSQLContext {
|WINDOW w1 AS (ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDiNG AND CURRENT RoW)
""".stripMargin)
- withSQLConf("spark.sql.windowExec.buffer.spill.threshold" -> "1") {
+ withSQLConf("spark.sql.windowExec.buffer.in.memory.threshold" -> "1",
+ "spark.sql.windowExec.buffer.spill.threshold" -> "2") {
assertSpilled(sparkContext, "test with low buffer spill threshold") {
checkAnswer(actual, expected)
}
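The WINDOW clauses above define RANGE-based frames over the ordering column. Roughly the same frames can be written with the DataFrame API; a minimal sketch assuming a local SparkSession (object and column names are illustrative):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.sum

object RangeFrameSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("range-frames").getOrCreate()
    import spark.implicits._

    val nums = spark.range(1, 11).selectExpr("id AS x", "id % 2 AS y")

    // RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
    val history = Window.partitionBy($"y").orderBy($"x")
      .rangeBetween(Window.unboundedPreceding, Window.currentRow)
    // RANGE BETWEEN 2 PRECEDING AND 2 FOLLOWING
    val sliding = Window.partitionBy($"y").orderBy($"x").rangeBetween(-2, 2)

    nums.select($"y", $"x",
        sum($"x").over(history).as("history_sum"),
        sum($"x").over(sliding).as("period_sum1"))
      .orderBy($"y", $"x")
      .show()

    spark.stop()
  }
}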
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SameResultSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SameResultSuite.scala
new file mode 100644
index 0000000000000..aaf51b5b90111
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SameResultSuite.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.{DataFrame, QueryTest}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSQLContext
+
+/**
+ * Tests for the sameResult function for [[SparkPlan]]s.
+ */
+class SameResultSuite extends QueryTest with SharedSQLContext {
+ import testImplicits._
+
+ test("FileSourceScanExec: different orders of data filters and partition filters") {
+ withTempPath { path =>
+ val tmpDir = path.getCanonicalPath
+ spark.range(10)
+ .selectExpr("id as a", "id + 1 as b", "id + 2 as c", "id + 3 as d")
+ .write
+ .partitionBy("a", "b")
+ .parquet(tmpDir)
+ val df = spark.read.parquet(tmpDir)
+ // partition filters: a > 1 AND b < 9
+ // data filters: c > 1 AND d < 9
+ val plan1 = getFileSourceScanExec(df.where("a > 1 AND b < 9 AND c > 1 AND d < 9"))
+ val plan2 = getFileSourceScanExec(df.where("b < 9 AND a > 1 AND d < 9 AND c > 1"))
+ assert(plan1.sameResult(plan2))
+ }
+ }
+
+ private def getFileSourceScanExec(df: DataFrame): FileSourceScanExec = {
+ df.queryExecution.sparkPlan.find(_.isInstanceOf[FileSourceScanExec]).get
+ .asInstanceOf[FileSourceScanExec]
+ }
+
+ test("SPARK-20725: partial aggregate should behave correctly for sameResult") {
+ val df1 = spark.range(10).agg(sum($"id"))
+ val df2 = spark.range(10).agg(sum($"id"))
+ assert(df1.queryExecution.executedPlan.sameResult(df2.queryExecution.executedPlan))
+
+ val df3 = spark.range(10).agg(sumDistinct($"id"))
+ val df4 = spark.range(10).agg(sumDistinct($"id"))
+ assert(df3.queryExecution.executedPlan.sameResult(df4.queryExecution.executedPlan))
+ }
+}
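sameResult compares canonicalized plans, so cosmetic differences such as expression IDs (and, with the FileSourceScanExec fix exercised above, the order of pushed-down filters) do not matter. A minimal sketch of the API, assuming a local SparkSession (the object name is illustrative):

import org.apache.spark.sql.SparkSession

object SameResultSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("same-result").getOrCreate()
    import spark.implicits._

    // Two independently built DataFrames get different expression IDs, but
    // canonicalization normalizes those away before the plans are compared.
    val df1 = spark.range(100).select(($"id" + 1).as("x"))
    val df2 = spark.range(100).select(($"id" + 1).as("x"))
    println(df1.queryExecution.optimizedPlan.sameResult(df2.queryExecution.optimizedPlan))

    spark.stop()
  }
}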
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
index 6cf18de0cc768..3b77657762517 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
@@ -23,6 +23,7 @@ import scala.collection.mutable
import scala.util.{Random, Try}
import scala.util.control.NonFatal
+import org.mockito.Mockito._
import org.scalatest.Matchers
import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext, TaskContextImpl}
@@ -53,6 +54,8 @@ class UnsafeFixedWidthAggregationMapSuite
private var memoryManager: TestMemoryManager = null
private var taskMemoryManager: TaskMemoryManager = null
+ private var taskContext: TaskContext = null
+
def testWithMemoryLeakDetection(name: String)(f: => Unit) {
def cleanup(): Unit = {
if (taskMemoryManager != null) {
@@ -66,9 +69,12 @@ class UnsafeFixedWidthAggregationMapSuite
val conf = new SparkConf().set("spark.memory.offHeap.enabled", "false")
memoryManager = new TestMemoryManager(conf)
taskMemoryManager = new TaskMemoryManager(memoryManager, 0)
+ taskContext = mock(classOf[TaskContext])
+ when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager)
TaskContext.setTaskContext(new TaskContextImpl(
stageId = 0,
+ stageAttemptNumber = 0,
partitionId = 0,
taskAttemptId = Random.nextInt(10000),
attemptNumber = 0,
@@ -109,7 +115,7 @@ class UnsafeFixedWidthAggregationMapSuite
emptyAggregationBuffer,
aggBufferSchema,
groupKeySchema,
- taskMemoryManager,
+ taskContext,
1024, // initial capacity,
PAGE_SIZE_BYTES,
false // disable perf metrics
@@ -123,7 +129,7 @@ class UnsafeFixedWidthAggregationMapSuite
emptyAggregationBuffer,
aggBufferSchema,
groupKeySchema,
- taskMemoryManager,
+ taskContext,
1024, // initial capacity
PAGE_SIZE_BYTES,
false // disable perf metrics
@@ -150,7 +156,7 @@ class UnsafeFixedWidthAggregationMapSuite
emptyAggregationBuffer,
aggBufferSchema,
groupKeySchema,
- taskMemoryManager,
+ taskContext,
128, // initial capacity
PAGE_SIZE_BYTES,
false // disable perf metrics
@@ -176,7 +182,7 @@ class UnsafeFixedWidthAggregationMapSuite
emptyAggregationBuffer,
aggBufferSchema,
groupKeySchema,
- taskMemoryManager,
+ taskContext,
128, // initial capacity
PAGE_SIZE_BYTES,
false // disable perf metrics
@@ -224,7 +230,7 @@ class UnsafeFixedWidthAggregationMapSuite
emptyAggregationBuffer,
aggBufferSchema,
groupKeySchema,
- taskMemoryManager,
+ taskContext,
128, // initial capacity
PAGE_SIZE_BYTES,
false // disable perf metrics
@@ -265,7 +271,7 @@ class UnsafeFixedWidthAggregationMapSuite
emptyAggregationBuffer,
StructType(Nil),
StructType(Nil),
- taskMemoryManager,
+ taskContext,
128, // initial capacity
PAGE_SIZE_BYTES,
false // disable perf metrics
@@ -310,7 +316,7 @@ class UnsafeFixedWidthAggregationMapSuite
emptyAggregationBuffer,
aggBufferSchema,
groupKeySchema,
- taskMemoryManager,
+ taskContext,
128, // initial capacity
pageSize,
false // disable perf metrics
@@ -348,7 +354,7 @@ class UnsafeFixedWidthAggregationMapSuite
emptyAggregationBuffer,
aggBufferSchema,
groupKeySchema,
- taskMemoryManager,
+ taskContext,
128, // initial capacity
pageSize,
false // disable perf metrics
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
index 3d869c77e9608..3a81fbaa6e104 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.map.BytesToBytesMap
import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
/**
@@ -116,6 +117,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
val taskMemMgr = new TaskMemoryManager(memoryManager, 0)
TaskContext.setTaskContext(new TaskContextImpl(
stageId = 0,
+ stageAttemptNumber = 0,
partitionId = 0,
taskAttemptId = 98456,
attemptNumber = 0,
@@ -204,4 +206,42 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
spill = true
)
}
+
+ test("SPARK-23376: Create UnsafeKVExternalSorter with BytesToByteMap having duplicated keys") {
+ val memoryManager = new TestMemoryManager(new SparkConf())
+ val taskMemoryManager = new TaskMemoryManager(memoryManager, 0)
+ val map = new BytesToBytesMap(taskMemoryManager, 64, taskMemoryManager.pageSizeBytes())
+
+ // Key and value are UnsafeRows with a single int column
+ val schema = new StructType().add("i", IntegerType)
+ val key = new UnsafeRow(1)
+ key.pointTo(new Array[Byte](32), 32)
+ key.setInt(0, 1)
+ val value = new UnsafeRow(1)
+ value.pointTo(new Array[Byte](32), 32)
+ value.setInt(0, 2)
+
+ for (_ <- 1 to 65) {
+ val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes)
+ loc.append(
+ key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
+ value.getBaseObject, value.getBaseOffset, value.getSizeInBytes)
+ }
+
+ // Make sure we can successfully create an UnsafeKVExternalSorter with a `BytesToBytesMap`
+ // that has duplicated keys and more entries than its capacity.
+ try {
+ TaskContext.setTaskContext(new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, null, null))
+ new UnsafeKVExternalSorter(
+ schema,
+ schema,
+ sparkContext.env.blockManager,
+ sparkContext.env.serializerManager,
+ taskMemoryManager.pageSizeBytes(),
+ Int.MaxValue,
+ map)
+ } finally {
+ TaskContext.unset()
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
index 53105e0b24959..c3ecf5208d59e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
@@ -114,7 +114,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext {
(i, converter(Row(i)))
}
val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0)
- val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, new Properties, null)
+ val taskContext = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties, null)
val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow](
taskContext,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index a4b30a2f8cec1..183c68fd3c016 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -22,8 +22,10 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{Add, Literal, Stack}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
+import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.expressions.scalalang.typed
import org.apache.spark.sql.functions.{avg, broadcast, col, max}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
@@ -127,4 +129,24 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
"named_struct('a',id+2, 'b',id+2) as col2")
.filter("col1 = col2").count()
}
+
+ test("SPARK-21441 SortMergeJoin codegen with CodegenFallback expressions should be disabled") {
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1") {
+ import testImplicits._
+
+ val df1 = Seq((1, 1), (2, 2), (3, 3)).toDF("key", "int")
+ val df2 = Seq((1, "1"), (2, "2"), (3, "3")).toDF("key", "str")
+
+ val df = df1.join(df2, df1("key") === df2("key"))
+ .filter("int = 2 or reflect('java.lang.Integer', 'valueOf', str) = 1")
+ .select("int")
+
+ val plan = df.queryExecution.executedPlan
+ assert(!plan.find(p =>
+ p.isInstanceOf[WholeStageCodegenExec] &&
+ p.asInstanceOf[WholeStageCodegenExec].child.children(0)
+ .isInstanceOf[SortMergeJoinExec]).isDefined)
+ assert(df.collect() === Array(Row(1), Row(2)))
+ }
+ }
}
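Whether a query ends up under whole-stage codegen can be checked by looking for WholeStageCodegenExec wrappers in the executed plan, which is what the SPARK-21441 test does for a join containing a CodegenFallback expression. A minimal sketch, assuming a local SparkSession (object and column names are illustrative):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.WholeStageCodegenExec

object CodegenInspectionSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("codegen-check").getOrCreate()

    val df = spark.range(1000).selectExpr("id % 10 AS k", "id AS v").groupBy("k").count()

    // Whole-stage codegen shows up as WholeStageCodegenExec nodes wrapping the
    // supported operators; expressions that only support CodegenFallback can
    // leave an operator outside those wrappers.
    val usesCodegen = df.queryExecution.executedPlan
      .find(_.isInstanceOf[WholeStageCodegenExec]).isDefined
    println(s"uses whole-stage codegen: $usesCodegen")
    df.explain()

    spark.stop()
  }
}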
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala
index bc9cb6ec2e771..3a9b34d7533b6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala
@@ -35,7 +35,8 @@ class SortBasedAggregationStoreSuite extends SparkFunSuite with LocalSparkConte
val conf = new SparkConf()
sc = new SparkContext("local[2, 4]", "test", conf)
val taskManager = new TaskMemoryManager(new TestMemoryManager(conf), 0)
- TaskContext.setTaskContext(new TaskContextImpl(0, 0, 0, 0, taskManager, new Properties, null))
+ TaskContext.setTaskContext(
+ new TaskContextImpl(0, 0, 0, 0, 0, taskManager, new Properties, null))
}
override def afterAll(): Unit = TaskContext.unset()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala
index 8a2993bdf4b28..8a798fb444696 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala
@@ -107,6 +107,7 @@ class AggregateBenchmark extends BenchmarkBase {
benchmark.addCase(s"codegen = T hashmap = F", numIters = 3) { iter =>
sparkSession.conf.set("spark.sql.codegen.wholeStage", "true")
sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false")
+ sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false")
f()
}
@@ -148,6 +149,7 @@ class AggregateBenchmark extends BenchmarkBase {
benchmark.addCase(s"codegen = T hashmap = F", numIters = 3) { iter =>
sparkSession.conf.set("spark.sql.codegen.wholeStage", value = true)
sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false")
+ sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false")
f()
}
@@ -187,6 +189,7 @@ class AggregateBenchmark extends BenchmarkBase {
benchmark.addCase(s"codegen = T hashmap = F", numIters = 3) { iter =>
sparkSession.conf.set("spark.sql.codegen.wholeStage", "true")
sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false")
+ sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false")
f()
}
@@ -225,6 +228,7 @@ class AggregateBenchmark extends BenchmarkBase {
benchmark.addCase(s"codegen = T hashmap = F") { iter =>
sparkSession.conf.set("spark.sql.codegen.wholeStage", "true")
sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false")
+ sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false")
f()
}
@@ -273,6 +277,7 @@ class AggregateBenchmark extends BenchmarkBase {
benchmark.addCase(s"codegen = T hashmap = F") { iter =>
sparkSession.conf.set("spark.sql.codegen.wholeStage", "true")
sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false")
+ sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false")
f()
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala
index 239822b72034a..a6249ce021400 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala
@@ -43,6 +43,7 @@ object TPCDSQueryBenchmark {
.set("spark.driver.memory", "3g")
.set("spark.executor.memory", "3g")
.set("spark.sql.autoBroadcastJoinThreshold", (20 * 1024 * 1024).toString)
+ .set("spark.sql.crossJoin.enabled", "true")
val spark = SparkSession.builder.config(conf).getOrCreate()
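Some TPC-DS queries contain implicit cartesian products, which Spark rejects unless spark.sql.crossJoin.enabled is set, hence the extra benchmark config above. A minimal sketch of that behavior, assuming a local SparkSession (the object name is illustrative):

import org.apache.spark.sql.{AnalysisException, SparkSession}

object CrossJoinConfSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("cross-join-conf").getOrCreate()

    val left = spark.range(3)
    val right = spark.range(3)

    // With the default spark.sql.crossJoin.enabled=false, a join without any
    // condition is rejected when the query is planned.
    try {
      left.join(right).count()
    } catch {
      case e: AnalysisException => println(s"rejected: ${e.getMessage}")
    }

    // Enabling the flag, as the benchmark configuration does, allows it.
    spark.conf.set("spark.sql.crossJoin.enabled", "true")
    println(left.join(right).count()) // 9

    spark.stop()
  }
}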
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
index 1e6a6a8ba3362..b049b60d5b22e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
@@ -21,8 +21,9 @@ import java.nio.charset.StandardCharsets
import java.sql.{Date, Timestamp}
import org.apache.spark.sql.{DataFrame, QueryTest, Row}
-import org.apache.spark.sql.catalyst.expressions.AttributeSet
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, In}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
+import org.apache.spark.sql.execution.LocalTableScanExec
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
@@ -414,4 +415,43 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
assert(partitionedAttrs.subsetOf(inMemoryScan.outputSet))
}
}
+
+ test("SPARK-20356: pruned InMemoryTableScanExec should have correct ordering and partitioning") {
+ withSQLConf("spark.sql.shuffle.partitions" -> "200") {
+ val df1 = Seq(("a", 1), ("b", 1), ("c", 2)).toDF("item", "group")
+ val df2 = Seq(("a", 1), ("b", 2), ("c", 3)).toDF("item", "id")
+ val df3 = df1.join(df2, Seq("item")).select($"id", $"group".as("item")).distinct()
+
+ df3.unpersist()
+ val agg_without_cache = df3.groupBy($"item").count()
+
+ df3.cache()
+ val agg_with_cache = df3.groupBy($"item").count()
+ checkAnswer(agg_without_cache, agg_with_cache)
+ }
+ }
+
+ test("SPARK-22249: IN should work also with cached DataFrame") {
+ val df = spark.range(10).cache()
+ // with an empty list
+ assert(df.filter($"id".isin()).count() == 0)
+ // with a non-empty list
+ assert(df.filter($"id".isin(2)).count() == 1)
+ assert(df.filter($"id".isin(2, 3)).count() == 2)
+ df.unpersist()
+ val dfNulls = spark.range(10).selectExpr("null as id").cache()
+ // with null as value for the attribute
+ assert(dfNulls.filter($"id".isin()).count() == 0)
+ assert(dfNulls.filter($"id".isin(2, 3)).count() == 0)
+ dfNulls.unpersist()
+ }
+
+ test("SPARK-22249: buildFilter should not throw exception when In contains an empty list") {
+ val attribute = AttributeReference("a", IntegerType)()
+ val testRelation = InMemoryRelation(false, 1, MEMORY_ONLY,
+ LocalTableScanExec(Seq(attribute), Nil), None)
+ val tableScanExec = InMemoryTableScanExec(Seq(attribute),
+ Seq(In(attribute, Nil)), testRelation)
+ assert(tableScanExec.partitionFilters.isEmpty)
+ }
}
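The SPARK-22249 tests exercise Column.isin against a cached relation, including the empty-value-list case that used to break the in-memory partition-filter builder. A minimal sketch of the user-facing behavior, assuming a local SparkSession (the object name is illustrative):

import org.apache.spark.sql.SparkSession

object CachedIsinSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("cached-isin").getOrCreate()
    import spark.implicits._

    val df = spark.range(10).cache()
    // isin() with no values never matches, so the filter returns nothing.
    println(df.filter($"id".isin()).count())     // 0
    println(df.filter($"id".isin(2, 3)).count()) // 2
    df.unpersist()

    spark.stop()
  }
}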
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala
index 97c61dc8694bc..8a6bc62fec96c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala
@@ -530,13 +530,13 @@ class DDLCommandSuite extends PlanTest {
""".stripMargin
val sql4 =
"""
- |ALTER TABLE table_name PARTITION (test, dt='2008-08-08',
+ |ALTER TABLE table_name PARTITION (test=1, dt='2008-08-08',
|country='us') SET SERDE 'org.apache.class' WITH SERDEPROPERTIES ('columns'='foo,bar',
|'field.delim' = ',')
""".stripMargin
val sql5 =
"""
- |ALTER TABLE table_name PARTITION (test, dt='2008-08-08',
+ |ALTER TABLE table_name PARTITION (test=1, dt='2008-08-08',
|country='us') SET SERDEPROPERTIES ('columns'='foo,bar', 'field.delim' = ',')
""".stripMargin
val parsed1 = parser.parsePlan(sql1)
@@ -558,12 +558,12 @@ class DDLCommandSuite extends PlanTest {
tableIdent,
Some("org.apache.class"),
Some(Map("columns" -> "foo,bar", "field.delim" -> ",")),
- Some(Map("test" -> null, "dt" -> "2008-08-08", "country" -> "us")))
+ Some(Map("test" -> "1", "dt" -> "2008-08-08", "country" -> "us")))
val expected5 = AlterTableSerDePropertiesCommand(
tableIdent,
None,
Some(Map("columns" -> "foo,bar", "field.delim" -> ",")),
- Some(Map("test" -> null, "dt" -> "2008-08-08", "country" -> "us")))
+ Some(Map("test" -> "1", "dt" -> "2008-08-08", "country" -> "us")))
comparePlans(parsed1, expected1)
comparePlans(parsed2, expected2)
comparePlans(parsed3, expected3)
@@ -832,6 +832,14 @@ class DDLCommandSuite extends PlanTest {
assert(e.contains("Found duplicate keys 'a'"))
}
+ test("empty values in non-optional partition specs") {
+ val e = intercept[ParseException] {
+ parser.parsePlan(
+ "SHOW PARTITIONS dbx.tab1 PARTITION (a='1', b)")
+ }.getMessage
+ assert(e.contains("Found an empty partition key 'b'"))
+ }
+
test("drop table") {
val tableName1 = "db.tab"
val tableName2 = "tab"
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
index fe74ab49f91bd..0c067233a6e74 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
@@ -49,7 +49,8 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSQLContext with Befo
protected override def generateTable(
catalog: SessionCatalog,
- name: TableIdentifier): CatalogTable = {
+ name: TableIdentifier,
+ isDataSource: Boolean = true): CatalogTable = {
val storage =
CatalogStorageFormat.empty.copy(locationUri = Some(catalog.defaultTablePath(name)))
val metadata = new MetadataBuilder()
@@ -70,46 +71,6 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSQLContext with Befo
tracksPartitionsInCatalog = true)
}
- test("alter table: set location (datasource table)") {
- testSetLocation(isDatasourceTable = true)
- }
-
- test("alter table: set properties (datasource table)") {
- testSetProperties(isDatasourceTable = true)
- }
-
- test("alter table: unset properties (datasource table)") {
- testUnsetProperties(isDatasourceTable = true)
- }
-
- test("alter table: set serde (datasource table)") {
- testSetSerde(isDatasourceTable = true)
- }
-
- test("alter table: set serde partition (datasource table)") {
- testSetSerdePartition(isDatasourceTable = true)
- }
-
- test("alter table: change column (datasource table)") {
- testChangeColumn(isDatasourceTable = true)
- }
-
- test("alter table: add partition (datasource table)") {
- testAddPartitions(isDatasourceTable = true)
- }
-
- test("alter table: drop partition (datasource table)") {
- testDropPartitions(isDatasourceTable = true)
- }
-
- test("alter table: rename partition (datasource table)") {
- testRenamePartitions(isDatasourceTable = true)
- }
-
- test("drop table - data source table") {
- testDropTable(isDatasourceTable = true)
- }
-
test("create a managed Hive source table") {
assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory")
val tabName = "tbl"
@@ -163,7 +124,10 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "hive"
}
- protected def generateTable(catalog: SessionCatalog, name: TableIdentifier): CatalogTable
+ protected def generateTable(
+ catalog: SessionCatalog,
+ name: TableIdentifier,
+ isDataSource: Boolean = true): CatalogTable
private val escapedIdentifier = "`(.+)`".r
@@ -205,8 +169,11 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
ignoreIfExists = false)
}
- private def createTable(catalog: SessionCatalog, name: TableIdentifier): Unit = {
- catalog.createTable(generateTable(catalog, name), ignoreIfExists = false)
+ private def createTable(
+ catalog: SessionCatalog,
+ name: TableIdentifier,
+ isDataSource: Boolean = true): Unit = {
+ catalog.createTable(generateTable(catalog, name, isDataSource), ignoreIfExists = false)
}
private def createTablePartition(
@@ -223,6 +190,46 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
new Path(CatalogUtils.URIToString(warehousePath), s"$dbName.db").toUri
}
+ test("alter table: set location (datasource table)") {
+ testSetLocation(isDatasourceTable = true)
+ }
+
+ test("alter table: set properties (datasource table)") {
+ testSetProperties(isDatasourceTable = true)
+ }
+
+ test("alter table: unset properties (datasource table)") {
+ testUnsetProperties(isDatasourceTable = true)
+ }
+
+ test("alter table: set serde (datasource table)") {
+ testSetSerde(isDatasourceTable = true)
+ }
+
+ test("alter table: set serde partition (datasource table)") {
+ testSetSerdePartition(isDatasourceTable = true)
+ }
+
+ test("alter table: change column (datasource table)") {
+ testChangeColumn(isDatasourceTable = true)
+ }
+
+ test("alter table: add partition (datasource table)") {
+ testAddPartitions(isDatasourceTable = true)
+ }
+
+ test("alter table: drop partition (datasource table)") {
+ testDropPartitions(isDatasourceTable = true)
+ }
+
+ test("alter table: rename partition (datasource table)") {
+ testRenamePartitions(isDatasourceTable = true)
+ }
+
+ test("drop table - data source table") {
+ testDropTable(isDatasourceTable = true)
+ }
+
test("the qualified path of a database is stored in the catalog") {
val catalog = spark.sessionState.catalog
@@ -695,7 +702,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
withView("testview") {
sql(s"CREATE OR REPLACE TEMPORARY VIEW testview (c1 String, c2 String) USING " +
"org.apache.spark.sql.execution.datasources.csv.CSVFileFormat " +
- s"OPTIONS (PATH '$tmpFile')")
+ s"OPTIONS (PATH '${tmpFile.toURI}')")
checkAnswer(
sql("select c1, c2 from testview order by c1 limit 1"),
@@ -707,7 +714,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
s"""
|CREATE TEMPORARY VIEW testview
|USING org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
- |OPTIONS (PATH '$tmpFile')
+ |OPTIONS (PATH '${tmpFile.toURI}')
""".stripMargin)
}
}
@@ -751,7 +758,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
val df = (1 to 2).map { i => (i, i.toString) }.toDF("age", "name")
df.write.insertInto("students")
spark.catalog.cacheTable("students")
- assume(spark.table("students").collect().toSeq == df.collect().toSeq, "bad test: wrong data")
+ checkAnswer(spark.table("students"), df)
assume(spark.catalog.isCached("students"), "bad test: table was not cached in the first place")
sql("ALTER TABLE students RENAME TO teachers")
sql("CREATE TABLE students (age INT, name STRING) USING parquet")
@@ -760,7 +767,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
assert(!spark.catalog.isCached("students"))
assert(spark.catalog.isCached("teachers"))
assert(spark.table("students").collect().isEmpty)
- assert(spark.table("teachers").collect().toSeq == df.collect().toSeq)
+ checkAnswer(spark.table("teachers"), df)
}
test("rename temporary table - destination table with database name") {
@@ -793,10 +800,11 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
spark.range(10).createOrReplaceTempView("tab1")
sql("ALTER TABLE tab1 RENAME TO tab2")
checkAnswer(spark.table("tab2"), spark.range(10).toDF())
- intercept[NoSuchTableException] { spark.table("tab1") }
+ val e = intercept[AnalysisException](spark.table("tab1")).getMessage
+ assert(e.contains("Table or view not found"))
sql("ALTER VIEW tab2 RENAME TO tab1")
checkAnswer(spark.table("tab1"), spark.range(10).toDF())
- intercept[NoSuchTableException] { spark.table("tab2") }
+ intercept[AnalysisException] { spark.table("tab2") }
}
}
@@ -835,32 +843,6 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
}
}
- test("alter table: set location") {
- testSetLocation(isDatasourceTable = false)
- }
-
- test("alter table: set properties") {
- testSetProperties(isDatasourceTable = false)
- }
-
- test("alter table: unset properties") {
- testUnsetProperties(isDatasourceTable = false)
- }
-
- // TODO: move this test to HiveDDLSuite.scala
- ignore("alter table: set serde") {
- testSetSerde(isDatasourceTable = false)
- }
-
- // TODO: move this test to HiveDDLSuite.scala
- ignore("alter table: set serde partition") {
- testSetSerdePartition(isDatasourceTable = false)
- }
-
- test("alter table: change column") {
- testChangeColumn(isDatasourceTable = false)
- }
-
test("alter table: bucketing is not supported") {
val catalog = spark.sessionState.catalog
val tableIdent = TableIdentifier("tab1", Some("dbx"))
@@ -885,10 +867,6 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
assertUnsupported("ALTER TABLE dbx.tab1 NOT STORED AS DIRECTORIES")
}
- test("alter table: add partition") {
- testAddPartitions(isDatasourceTable = false)
- }
-
test("alter table: recover partitions (sequential)") {
withSQLConf("spark.rdd.parallelListingThreshold" -> "10") {
testRecoverPartitions()
@@ -957,17 +935,10 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
assertUnsupported("ALTER VIEW dbx.tab1 ADD IF NOT EXISTS PARTITION (b='2')")
}
- test("alter table: drop partition") {
- testDropPartitions(isDatasourceTable = false)
- }
-
test("alter table: drop partition is not supported for views") {
assertUnsupported("ALTER VIEW dbx.tab1 DROP IF EXISTS PARTITION (b='2')")
}
- test("alter table: rename partition") {
- testRenamePartitions(isDatasourceTable = false)
- }
test("show databases") {
sql("CREATE DATABASE showdb2B")
@@ -1011,18 +982,14 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
assert(catalog.listTables("default") == Nil)
}
- test("drop table") {
- testDropTable(isDatasourceTable = false)
- }
-
protected def testDropTable(isDatasourceTable: Boolean): Unit = {
+ if (!isUsingHiveMetastore) {
+ assert(isDatasourceTable, "InMemoryCatalog only supports data source tables")
+ }
val catalog = spark.sessionState.catalog
val tableIdent = TableIdentifier("tab1", Some("dbx"))
createDatabase(catalog, "dbx")
- createTable(catalog, tableIdent)
- if (isDatasourceTable) {
- convertToDatasourceTable(catalog, tableIdent)
- }
+ createTable(catalog, tableIdent, isDatasourceTable)
assert(catalog.listTables("dbx") == Seq(tableIdent))
sql("DROP TABLE dbx.tab1")
assert(catalog.listTables("dbx") == Nil)
@@ -1046,22 +1013,14 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
e.getMessage.contains("Cannot drop a table with DROP VIEW. Please use DROP TABLE instead"))
}
- private def convertToDatasourceTable(
- catalog: SessionCatalog,
- tableIdent: TableIdentifier): Unit = {
- catalog.alterTable(catalog.getTableMetadata(tableIdent).copy(
- provider = Some("csv")))
- assert(catalog.getTableMetadata(tableIdent).provider == Some("csv"))
- }
-
protected def testSetProperties(isDatasourceTable: Boolean): Unit = {
+ if (!isUsingHiveMetastore) {
+ assert(isDatasourceTable, "InMemoryCatalog only supports data source tables")
+ }
val catalog = spark.sessionState.catalog
val tableIdent = TableIdentifier("tab1", Some("dbx"))
createDatabase(catalog, "dbx")
- createTable(catalog, tableIdent)
- if (isDatasourceTable) {
- convertToDatasourceTable(catalog, tableIdent)
- }
+ createTable(catalog, tableIdent, isDatasourceTable)
def getProps: Map[String, String] = {
if (isUsingHiveMetastore) {
normalizeCatalogTable(catalog.getTableMetadata(tableIdent)).properties
@@ -1084,13 +1043,13 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
}
protected def testUnsetProperties(isDatasourceTable: Boolean): Unit = {
+ if (!isUsingHiveMetastore) {
+ assert(isDatasourceTable, "InMemoryCatalog only supports data source tables")
+ }
val catalog = spark.sessionState.catalog
val tableIdent = TableIdentifier("tab1", Some("dbx"))
createDatabase(catalog, "dbx")
- createTable(catalog, tableIdent)
- if (isDatasourceTable) {
- convertToDatasourceTable(catalog, tableIdent)
- }
+ createTable(catalog, tableIdent, isDatasourceTable)
def getProps: Map[String, String] = {
if (isUsingHiveMetastore) {
normalizeCatalogTable(catalog.getTableMetadata(tableIdent)).properties
@@ -1121,15 +1080,15 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
}
protected def testSetLocation(isDatasourceTable: Boolean): Unit = {
+ if (!isUsingHiveMetastore) {
+ assert(isDatasourceTable, "InMemoryCatalog only supports data source tables")
+ }
val catalog = spark.sessionState.catalog
val tableIdent = TableIdentifier("tab1", Some("dbx"))
val partSpec = Map("a" -> "1", "b" -> "2")
createDatabase(catalog, "dbx")
- createTable(catalog, tableIdent)
+ createTable(catalog, tableIdent, isDatasourceTable)
createTablePartition(catalog, partSpec, tableIdent)
- if (isDatasourceTable) {
- convertToDatasourceTable(catalog, tableIdent)
- }
assert(catalog.getTableMetadata(tableIdent).storage.locationUri.isDefined)
assert(normalizeSerdeProp(catalog.getTableMetadata(tableIdent).storage.properties).isEmpty)
assert(catalog.getPartition(tableIdent, partSpec).storage.locationUri.isDefined)
@@ -1171,13 +1130,13 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
}
protected def testSetSerde(isDatasourceTable: Boolean): Unit = {
+ if (!isUsingHiveMetastore) {
+ assert(isDatasourceTable, "InMemoryCatalog only supports data source tables")
+ }
val catalog = spark.sessionState.catalog
val tableIdent = TableIdentifier("tab1", Some("dbx"))
createDatabase(catalog, "dbx")
- createTable(catalog, tableIdent)
- if (isDatasourceTable) {
- convertToDatasourceTable(catalog, tableIdent)
- }
+ createTable(catalog, tableIdent, isDatasourceTable)
def checkSerdeProps(expectedSerdeProps: Map[String, String]): Unit = {
val serdeProp = catalog.getTableMetadata(tableIdent).storage.properties
if (isUsingHiveMetastore) {
@@ -1187,8 +1146,12 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
}
}
if (isUsingHiveMetastore) {
- assert(catalog.getTableMetadata(tableIdent).storage.serde ==
- Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"))
+ val expectedSerde = if (isDatasourceTable) {
+ "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe"
+ } else {
+ "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"
+ }
+ assert(catalog.getTableMetadata(tableIdent).storage.serde == Some(expectedSerde))
} else {
assert(catalog.getTableMetadata(tableIdent).storage.serde.isEmpty)
}
@@ -1229,18 +1192,18 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
}
protected def testSetSerdePartition(isDatasourceTable: Boolean): Unit = {
+ if (!isUsingHiveMetastore) {
+ assert(isDatasourceTable, "InMemoryCatalog only supports data source tables")
+ }
val catalog = spark.sessionState.catalog
val tableIdent = TableIdentifier("tab1", Some("dbx"))
val spec = Map("a" -> "1", "b" -> "2")
createDatabase(catalog, "dbx")
- createTable(catalog, tableIdent)
+ createTable(catalog, tableIdent, isDatasourceTable)
createTablePartition(catalog, spec, tableIdent)
createTablePartition(catalog, Map("a" -> "1", "b" -> "3"), tableIdent)
createTablePartition(catalog, Map("a" -> "2", "b" -> "2"), tableIdent)
createTablePartition(catalog, Map("a" -> "2", "b" -> "3"), tableIdent)
- if (isDatasourceTable) {
- convertToDatasourceTable(catalog, tableIdent)
- }
def checkPartitionSerdeProps(expectedSerdeProps: Map[String, String]): Unit = {
val serdeProp = catalog.getPartition(tableIdent, spec).storage.properties
if (isUsingHiveMetastore) {
@@ -1250,8 +1213,12 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
}
}
if (isUsingHiveMetastore) {
- assert(catalog.getPartition(tableIdent, spec).storage.serde ==
- Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"))
+ val expectedSerde = if (isDatasourceTable) {
+ "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe"
+ } else {
+ "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"
+ }
+ assert(catalog.getPartition(tableIdent, spec).storage.serde == Some(expectedSerde))
} else {
assert(catalog.getPartition(tableIdent, spec).storage.serde.isEmpty)
}
@@ -1295,6 +1262,9 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
}
protected def testAddPartitions(isDatasourceTable: Boolean): Unit = {
+ if (!isUsingHiveMetastore) {
+ assert(isDatasourceTable, "InMemoryCatalog only supports data source tables")
+ }
val catalog = spark.sessionState.catalog
val tableIdent = TableIdentifier("tab1", Some("dbx"))
val part1 = Map("a" -> "1", "b" -> "5")
@@ -1303,11 +1273,8 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
val part4 = Map("a" -> "4", "b" -> "8")
val part5 = Map("a" -> "9", "b" -> "9")
createDatabase(catalog, "dbx")
- createTable(catalog, tableIdent)
+ createTable(catalog, tableIdent, isDatasourceTable)
createTablePartition(catalog, part1, tableIdent)
- if (isDatasourceTable) {
- convertToDatasourceTable(catalog, tableIdent)
- }
assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1))
// basic add partition
@@ -1354,6 +1321,9 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
}
protected def testDropPartitions(isDatasourceTable: Boolean): Unit = {
+ if (!isUsingHiveMetastore) {
+ assert(isDatasourceTable, "InMemoryCatalog only supports data source tables")
+ }
val catalog = spark.sessionState.catalog
val tableIdent = TableIdentifier("tab1", Some("dbx"))
val part1 = Map("a" -> "1", "b" -> "5")
@@ -1362,7 +1332,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
val part4 = Map("a" -> "4", "b" -> "8")
val part5 = Map("a" -> "9", "b" -> "9")
createDatabase(catalog, "dbx")
- createTable(catalog, tableIdent)
+ createTable(catalog, tableIdent, isDatasourceTable)
createTablePartition(catalog, part1, tableIdent)
createTablePartition(catalog, part2, tableIdent)
createTablePartition(catalog, part3, tableIdent)
@@ -1370,9 +1340,6 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
createTablePartition(catalog, part5, tableIdent)
assert(catalog.listPartitions(tableIdent).map(_.spec).toSet ==
Set(part1, part2, part3, part4, part5))
- if (isDatasourceTable) {
- convertToDatasourceTable(catalog, tableIdent)
- }
// basic drop partition
sql("ALTER TABLE dbx.tab1 DROP IF EXISTS PARTITION (a='4', b='8'), PARTITION (a='3', b='7')")
@@ -1407,20 +1374,20 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
}
protected def testRenamePartitions(isDatasourceTable: Boolean): Unit = {
+ if (!isUsingHiveMetastore) {
+ assert(isDatasourceTable, "InMemoryCatalog only supports data source tables")
+ }
val catalog = spark.sessionState.catalog
val tableIdent = TableIdentifier("tab1", Some("dbx"))
val part1 = Map("a" -> "1", "b" -> "q")
val part2 = Map("a" -> "2", "b" -> "c")
val part3 = Map("a" -> "3", "b" -> "p")
createDatabase(catalog, "dbx")
- createTable(catalog, tableIdent)
+ createTable(catalog, tableIdent, isDatasourceTable)
createTablePartition(catalog, part1, tableIdent)
createTablePartition(catalog, part2, tableIdent)
createTablePartition(catalog, part3, tableIdent)
assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2, part3))
- if (isDatasourceTable) {
- convertToDatasourceTable(catalog, tableIdent)
- }
// basic rename partition
sql("ALTER TABLE dbx.tab1 PARTITION (a='1', b='q') RENAME TO PARTITION (a='100', b='p')")
@@ -1451,14 +1418,14 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
}
protected def testChangeColumn(isDatasourceTable: Boolean): Unit = {
+ if (!isUsingHiveMetastore) {
+ assert(isDatasourceTable, "InMemoryCatalog only supports data source tables")
+ }
val catalog = spark.sessionState.catalog
val resolver = spark.sessionState.conf.resolver
val tableIdent = TableIdentifier("tab1", Some("dbx"))
createDatabase(catalog, "dbx")
- createTable(catalog, tableIdent)
- if (isDatasourceTable) {
- convertToDatasourceTable(catalog, tableIdent)
- }
+ createTable(catalog, tableIdent, isDatasourceTable)
def getMetadata(colName: String): Metadata = {
val column = catalog.getTableMetadata(tableIdent).schema.fields.find { field =>
resolver(field.name, colName)
@@ -1468,6 +1435,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
// Ensure that change column will preserve other metadata fields.
sql("ALTER TABLE dbx.tab1 CHANGE COLUMN col1 col1 INT COMMENT 'this is col1'")
assert(getMetadata("col1").getString("key") == "value")
+ assert(getMetadata("col1").getString("comment") == "this is col1")
}
test("drop build-in function") {
@@ -1601,13 +1569,15 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
}
test("drop current database") {
- sql("CREATE DATABASE temp")
- sql("USE temp")
- sql("DROP DATABASE temp")
- val e = intercept[AnalysisException] {
+ withDatabase("temp") {
+ sql("CREATE DATABASE temp")
+ sql("USE temp")
+ sql("DROP DATABASE temp")
+ val e = intercept[AnalysisException] {
sql("CREATE TABLE t (a INT, b INT) USING parquet")
}.getMessage
- assert(e.contains("Database 'temp' not found"))
+ assert(e.contains("Database 'temp' not found"))
+ }
}
test("drop default database") {
@@ -1837,22 +1807,25 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
checkAnswer(spark.table("tbl"), Row(1))
val defaultTablePath = spark.sessionState.catalog
.getTableMetadata(TableIdentifier("tbl")).storage.locationUri.get
-
- sql(s"ALTER TABLE tbl SET LOCATION '${dir.toURI}'")
- spark.catalog.refreshTable("tbl")
- // SET LOCATION won't move data from previous table path to new table path.
- assert(spark.table("tbl").count() == 0)
- // the previous table path should be still there.
- assert(new File(defaultTablePath).exists())
-
- sql("INSERT INTO tbl SELECT 2")
- checkAnswer(spark.table("tbl"), Row(2))
- // newly inserted data will go to the new table path.
- assert(dir.listFiles().nonEmpty)
-
- sql("DROP TABLE tbl")
- // the new table path will be removed after DROP TABLE.
- assert(!dir.exists())
+ try {
+ sql(s"ALTER TABLE tbl SET LOCATION '${dir.toURI}'")
+ spark.catalog.refreshTable("tbl")
+ // SET LOCATION won't move data from previous table path to new table path.
+ assert(spark.table("tbl").count() == 0)
+ // the previous table path should be still there.
+ assert(new File(defaultTablePath).exists())
+
+ sql("INSERT INTO tbl SELECT 2")
+ checkAnswer(spark.table("tbl"), Row(2))
+ // newly inserted data will go to the new table path.
+ assert(dir.listFiles().nonEmpty)
+
+ sql("DROP TABLE tbl")
+ // the new table path will be removed after DROP TABLE.
+ assert(!dir.exists())
+ } finally {
+ Utils.deleteRecursively(new File(defaultTablePath))
+ }
}
}
}
@@ -1864,7 +1837,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
s"""
|CREATE TABLE t(a string, b int)
|USING parquet
- |OPTIONS(path "$dir")
+ |OPTIONS(path "${dir.toURI}")
""".stripMargin)
val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t"))
assert(table.location == makeQualifiedPath(dir.getAbsolutePath))
@@ -1882,12 +1855,12 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
checkAnswer(spark.table("t"), Row("c", 1) :: Nil)
val newDirFile = new File(dir, "x")
- val newDir = newDirFile.getAbsolutePath
+ val newDir = newDirFile.toURI
spark.sql(s"ALTER TABLE t SET LOCATION '$newDir'")
spark.sessionState.catalog.refreshTable(TableIdentifier("t"))
val table1 = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t"))
- assert(table1.location == new URI(newDir))
+ assert(table1.location == newDir)
assert(!newDirFile.exists)
spark.sql("INSERT INTO TABLE t SELECT 'c', 1")
@@ -1905,7 +1878,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
|CREATE TABLE t(a int, b int, c int, d int)
|USING parquet
|PARTITIONED BY(a, b)
- |LOCATION "$dir"
+ |LOCATION "${dir.toURI}"
""".stripMargin)
val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t"))
assert(table.location == makeQualifiedPath(dir.getAbsolutePath))
@@ -1931,7 +1904,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
s"""
|CREATE TABLE t(a string, b int)
|USING parquet
- |OPTIONS(path "$dir")
+ |OPTIONS(path "${dir.toURI}")
""".stripMargin)
val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t"))
@@ -1960,7 +1933,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
|CREATE TABLE t(a int, b int, c int, d int)
|USING parquet
|PARTITIONED BY(a, b)
- |LOCATION "$dir"
+ |LOCATION "${dir.toURI}"
""".stripMargin)
spark.sql("INSERT INTO TABLE t PARTITION(a=1, b=2) SELECT 3, 4")
checkAnswer(spark.table("t"), Row(3, 4, 1, 2) :: Nil)
@@ -1977,7 +1950,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
test("create datasource table with a non-existing location") {
withTable("t", "t1") {
withTempPath { dir =>
- spark.sql(s"CREATE TABLE t(a int, b int) USING parquet LOCATION '$dir'")
+ spark.sql(s"CREATE TABLE t(a int, b int) USING parquet LOCATION '${dir.toURI}'")
val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t"))
assert(table.location == makeQualifiedPath(dir.getAbsolutePath))
@@ -1989,7 +1962,8 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
}
// partition table
withTempPath { dir =>
- spark.sql(s"CREATE TABLE t1(a int, b int) USING parquet PARTITIONED BY(a) LOCATION '$dir'")
+ spark.sql(
+ s"CREATE TABLE t1(a int, b int) USING parquet PARTITIONED BY(a) LOCATION '${dir.toURI}'")
val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1"))
assert(table.location == makeQualifiedPath(dir.getAbsolutePath))
@@ -2014,7 +1988,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
s"""
|CREATE TABLE t
|USING parquet
- |LOCATION '$dir'
+ |LOCATION '${dir.toURI}'
|AS SELECT 3 as a, 4 as b, 1 as c, 2 as d
""".stripMargin)
val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t"))
@@ -2030,7 +2004,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
|CREATE TABLE t1
|USING parquet
|PARTITIONED BY(a, b)
- |LOCATION '$dir'
+ |LOCATION '${dir.toURI}'
|AS SELECT 3 as a, 4 as b, 1 as c, 2 as d
""".stripMargin)
val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1"))
@@ -2047,6 +2021,10 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
Seq("a b", "a:b", "a%b", "a,b").foreach { specialChars =>
test(s"data source table:partition column name containing $specialChars") {
+ // On Windows, a colon in the file name appears to be illegal by default. See
+ // https://support.microsoft.com/en-us/help/289627
+ assume(!Utils.isWindows || specialChars != "a:b")
+
withTable("t") {
withTempDir { dir =>
spark.sql(
@@ -2054,14 +2032,14 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
|CREATE TABLE t(a string, `$specialChars` string)
|USING parquet
|PARTITIONED BY(`$specialChars`)
- |LOCATION '$dir'
+ |LOCATION '${dir.toURI}'
""".stripMargin)
assert(dir.listFiles().isEmpty)
spark.sql(s"INSERT INTO TABLE t PARTITION(`$specialChars`=2) SELECT 1")
val partEscaped = s"${ExternalCatalogUtils.escapePathName(specialChars)}=2"
val partFile = new File(dir, partEscaped)
- assert(partFile.listFiles().length >= 1)
+ assert(partFile.listFiles().nonEmpty)
checkAnswer(spark.table("t"), Row("1", "2") :: Nil)
}
}
@@ -2070,15 +2048,22 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
Seq("a b", "a:b", "a%b").foreach { specialChars =>
test(s"location uri contains $specialChars for datasource table") {
+ // On Windows, a colon in the file name appears to be illegal by default. See
+ // https://support.microsoft.com/en-us/help/289627
+ assume(!Utils.isWindows || specialChars != "a:b")
+
withTable("t", "t1") {
withTempDir { dir =>
val loc = new File(dir, specialChars)
loc.mkdir()
+ // The parser does not handle Windows path backslashes as-is;
+ // they currently need to be escaped.
+ val escapedLoc = loc.getAbsolutePath.replace("\\", "\\\\")
spark.sql(
s"""
|CREATE TABLE t(a string)
|USING parquet
- |LOCATION '$loc'
+ |LOCATION '$escapedLoc'
""".stripMargin)
val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t"))
@@ -2087,19 +2072,22 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
assert(loc.listFiles().isEmpty)
spark.sql("INSERT INTO TABLE t SELECT 1")
- assert(loc.listFiles().length >= 1)
+ assert(loc.listFiles().nonEmpty)
checkAnswer(spark.table("t"), Row("1") :: Nil)
}
withTempDir { dir =>
val loc = new File(dir, specialChars)
loc.mkdir()
+ // The parser does not handle Windows path backslashes as-is;
+ // they currently need to be escaped.
+ val escapedLoc = loc.getAbsolutePath.replace("\\", "\\\\")
spark.sql(
s"""
|CREATE TABLE t1(a string, b string)
|USING parquet
|PARTITIONED BY(b)
- |LOCATION '$loc'
+ |LOCATION '$escapedLoc'
""".stripMargin)
val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1"))
@@ -2109,15 +2097,20 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
assert(loc.listFiles().isEmpty)
spark.sql("INSERT INTO TABLE t1 PARTITION(b=2) SELECT 1")
val partFile = new File(loc, "b=2")
- assert(partFile.listFiles().length >= 1)
+ assert(partFile.listFiles().nonEmpty)
checkAnswer(spark.table("t1"), Row("1", "2") :: Nil)
spark.sql("INSERT INTO TABLE t1 PARTITION(b='2017-03-03 12:13%3A14') SELECT 1")
val partFile1 = new File(loc, "b=2017-03-03 12:13%3A14")
assert(!partFile1.exists())
- val partFile2 = new File(loc, "b=2017-03-03 12%3A13%253A14")
- assert(partFile2.listFiles().length >= 1)
- checkAnswer(spark.table("t1"), Row("1", "2") :: Row("1", "2017-03-03 12:13%3A14") :: Nil)
+
+ if (!Utils.isWindows) {
+ // Actual path becomes "b=2017-03-03%2012%3A13%253A14" on Windows.
+ val partFile2 = new File(loc, "b=2017-03-03 12%3A13%253A14")
+ assert(partFile2.listFiles().nonEmpty)
+ checkAnswer(
+ spark.table("t1"), Row("1", "2") :: Row("1", "2017-03-03 12:13%3A14") :: Nil)
+ }
}
}
}
@@ -2125,11 +2118,18 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
Seq("a b", "a:b", "a%b").foreach { specialChars =>
test(s"location uri contains $specialChars for database") {
- try {
+ // On Windows, a colon in the file name appears to be illegal by default. See
+ // https://support.microsoft.com/en-us/help/289627
+ assume(!Utils.isWindows || specialChars != "a:b")
+
+ withDatabase ("tmpdb") {
withTable("t") {
withTempDir { dir =>
val loc = new File(dir, specialChars)
- spark.sql(s"CREATE DATABASE tmpdb LOCATION '$loc'")
+ // The parser does not handle Windows path backslashes as-is;
+ // they currently need to be escaped.
+ val escapedLoc = loc.getAbsolutePath.replace("\\", "\\\\")
+ spark.sql(s"CREATE DATABASE tmpdb LOCATION '$escapedLoc'")
spark.sql("USE tmpdb")
import testImplicits._
@@ -2140,8 +2140,6 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
assert(tblloc.listFiles().nonEmpty)
}
}
- } finally {
- spark.sql("DROP DATABASE IF EXISTS tmpdb")
}
}
}
@@ -2150,11 +2148,14 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
withTable("t", "t1") {
withTempDir { dir =>
assert(!dir.getAbsolutePath.startsWith("file:/"))
+ // The parser does not handle Windows path backslashes as-is;
+ // they currently need to be escaped.
+ val escapedDir = dir.getAbsolutePath.replace("\\", "\\\\")
spark.sql(
s"""
|CREATE TABLE t(a string)
|USING parquet
- |LOCATION '$dir'
+ |LOCATION '$escapedDir'
""".stripMargin)
val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t"))
assert(table.location.toString.startsWith("file:/"))
@@ -2162,12 +2163,15 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
withTempDir { dir =>
assert(!dir.getAbsolutePath.startsWith("file:/"))
+ // The parser does not handle Windows path backslashes as-is;
+ // they currently need to be escaped.
+ val escapedDir = dir.getAbsolutePath.replace("\\", "\\\\")
spark.sql(
s"""
|CREATE TABLE t1(a string, b string)
|USING parquet
|PARTITIONED BY(b)
- |LOCATION '$dir'
+ |LOCATION '$escapedDir'
""".stripMargin)
val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1"))
assert(table.location.toString.startsWith("file:/"))
@@ -2279,17 +2283,27 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
}.getMessage
assert(e.contains("Found duplicate column(s)"))
} else {
- if (isUsingHiveMetastore) {
- // hive catalog will still complains that c1 is duplicate column name because hive
- // identifiers are case insensitive.
- val e = intercept[AnalysisException] {
- sql("ALTER TABLE t1 ADD COLUMNS (C1 string)")
- }.getMessage
- assert(e.contains("HiveException"))
- } else {
- sql("ALTER TABLE t1 ADD COLUMNS (C1 string)")
- assert(spark.table("t1").schema
- .equals(new StructType().add("c1", IntegerType).add("C1", StringType)))
+ sql("ALTER TABLE t1 ADD COLUMNS (C1 string)")
+ assert(spark.table("t1").schema ==
+ new StructType().add("c1", IntegerType).add("C1", StringType))
+ }
+ }
+ }
+ }
+
+ test(s"basic DDL using locale tr - caseSensitive $caseSensitive") {
+ withSQLConf(SQLConf.CASE_SENSITIVE.key -> s"$caseSensitive") {
+ withLocale("tr") {
+ val dbName = "DaTaBaSe_I"
+ withDatabase(dbName) {
+ sql(s"CREATE DATABASE $dbName")
+ sql(s"USE $dbName")
+
+ val tabName = "tAb_I"
+ withTable(tabName) {
+ sql(s"CREATE TABLE $tabName(col_I int) USING PARQUET")
+ sql(s"INSERT OVERWRITE TABLE $tabName SELECT 1")
+ checkAnswer(sql(s"SELECT col_I FROM $tabName"), Row(1) :: Nil)
}
}
}
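The locale test above guards against locale-sensitive case conversion (the classic Turkish-i problem) leaking into identifier handling. A small, self-contained illustration of why the "tr" locale is special (the object name is illustrative):

import java.util.Locale

object TurkishLocaleSketch {
  def main(args: Array[String]): Unit = {
    val tr = new Locale("tr")
    // In the Turkish locale, case mapping of 'i'/'I' differs from ASCII:
    println("i".toUpperCase(Locale.ROOT)) // I
    println("i".toUpperCase(tr))          // İ (dotted capital I)
    println("I".toLowerCase(Locale.ROOT)) // i
    println("I".toLowerCase(tr))          // ı (dotless small i)
    // Identifiers like "DaTaBaSe_I" or "col_I" can therefore round-trip
    // incorrectly if code upper/lower-cases them with the default locale.
  }
}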
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala
new file mode 100644
index 0000000000000..f20aded169e44
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala
@@ -0,0 +1,231 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.sources
+import org.apache.spark.sql.test.SharedSQLContext
+
+class DataSourceStrategySuite extends PlanTest with SharedSQLContext {
+
+ test("translate simple expression") {
+ val attrInt = 'cint.int
+ val attrStr = 'cstr.string
+
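+ // Note: when the literal appears on the left-hand side of a comparison, the translated
+ // data source filter is flipped so that the column name always comes first (see below).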
+ testTranslateFilter(EqualTo(attrInt, 1), Some(sources.EqualTo("cint", 1)))
+ testTranslateFilter(EqualTo(1, attrInt), Some(sources.EqualTo("cint", 1)))
+
+ testTranslateFilter(EqualNullSafe(attrStr, Literal(null)),
+ Some(sources.EqualNullSafe("cstr", null)))
+ testTranslateFilter(EqualNullSafe(Literal(null), attrStr),
+ Some(sources.EqualNullSafe("cstr", null)))
+
+ testTranslateFilter(GreaterThan(attrInt, 1), Some(sources.GreaterThan("cint", 1)))
+ testTranslateFilter(GreaterThan(1, attrInt), Some(sources.LessThan("cint", 1)))
+
+ testTranslateFilter(LessThan(attrInt, 1), Some(sources.LessThan("cint", 1)))
+ testTranslateFilter(LessThan(1, attrInt), Some(sources.GreaterThan("cint", 1)))
+
+ testTranslateFilter(GreaterThanOrEqual(attrInt, 1), Some(sources.GreaterThanOrEqual("cint", 1)))
+ testTranslateFilter(GreaterThanOrEqual(1, attrInt), Some(sources.LessThanOrEqual("cint", 1)))
+
+ testTranslateFilter(LessThanOrEqual(attrInt, 1), Some(sources.LessThanOrEqual("cint", 1)))
+ testTranslateFilter(LessThanOrEqual(1, attrInt), Some(sources.GreaterThanOrEqual("cint", 1)))
+
+ testTranslateFilter(InSet(attrInt, Set(1, 2, 3)), Some(sources.In("cint", Array(1, 2, 3))))
+
+ testTranslateFilter(In(attrInt, Seq(1, 2, 3)), Some(sources.In("cint", Array(1, 2, 3))))
+
+ testTranslateFilter(IsNull(attrInt), Some(sources.IsNull("cint")))
+ testTranslateFilter(IsNotNull(attrInt), Some(sources.IsNotNull("cint")))
+
+ // cint > 1 AND cint < 10
+ testTranslateFilter(And(
+ GreaterThan(attrInt, 1),
+ LessThan(attrInt, 10)),
+ Some(sources.And(
+ sources.GreaterThan("cint", 1),
+ sources.LessThan("cint", 10))))
+
+ // cint >= 8 OR cint <= 2
+ testTranslateFilter(Or(
+ GreaterThanOrEqual(attrInt, 8),
+ LessThanOrEqual(attrInt, 2)),
+ Some(sources.Or(
+ sources.GreaterThanOrEqual("cint", 8),
+ sources.LessThanOrEqual("cint", 2))))
+
+ testTranslateFilter(Not(GreaterThanOrEqual(attrInt, 8)),
+ Some(sources.Not(sources.GreaterThanOrEqual("cint", 8))))
+
+ testTranslateFilter(StartsWith(attrStr, "a"), Some(sources.StringStartsWith("cstr", "a")))
+
+ testTranslateFilter(EndsWith(attrStr, "a"), Some(sources.StringEndsWith("cstr", "a")))
+
+ testTranslateFilter(Contains(attrStr, "a"), Some(sources.StringContains("cstr", "a")))
+ }
+
+ test("translate complex expression") {
+ val attrInt = 'cint.int
+
+ // ABS(cint) - 2 <= 1
+ testTranslateFilter(LessThanOrEqual(
+ // Arbitrary arithmetic expressions are not supported;
+ // functions such as 'Abs' cannot be translated
+ Subtract(Abs(attrInt), 2), 1), None)
+
+ // (cint > 1 AND cint < 10) OR (cint > 50 AND cint < 100)
+ testTranslateFilter(Or(
+ And(
+ GreaterThan(attrInt, 1),
+ LessThan(attrInt, 10)
+ ),
+ And(
+ GreaterThan(attrInt, 50),
+ LessThan(attrInt, 100))),
+ Some(sources.Or(
+ sources.And(
+ sources.GreaterThan("cint", 1),
+ sources.LessThan("cint", 10)),
+ sources.And(
+ sources.GreaterThan("cint", 50),
+ sources.LessThan("cint", 100)))))
+
+ // SPARK-22548 Incorrect nested AND expression pushed down to JDBC data source
+ // (cint > 1 AND ABS(cint) < 10) OR (cint > 50 AND cint < 100)
+ testTranslateFilter(Or(
+ And(
+ GreaterThan(attrInt, 1),
+ // Functions such as 'Abs' are not supported
+ LessThan(Abs(attrInt), 10)
+ ),
+ And(
+ GreaterThan(attrInt, 50),
+ LessThan(attrInt, 100))), None)
+
+ // NOT ((cint <= 1 OR ABS(cint) >= 10) AND (cint <= 50 OR cint >= 100))
+ testTranslateFilter(Not(And(
+ Or(
+ LessThanOrEqual(attrInt, 1),
+ // Functions such as 'Abs' are not supported
+ GreaterThanOrEqual(Abs(attrInt), 10)
+ ),
+ Or(
+ LessThanOrEqual(attrInt, 50),
+ GreaterThanOrEqual(attrInt, 100)))), None)
+
+ // (cint = 1 OR cint = 10) OR (cint > 0 OR cint < -10)
+ testTranslateFilter(Or(
+ Or(
+ EqualTo(attrInt, 1),
+ EqualTo(attrInt, 10)
+ ),
+ Or(
+ GreaterThan(attrInt, 0),
+ LessThan(attrInt, -10))),
+ Some(sources.Or(
+ sources.Or(
+ sources.EqualTo("cint", 1),
+ sources.EqualTo("cint", 10)),
+ sources.Or(
+ sources.GreaterThan("cint", 0),
+ sources.LessThan("cint", -10)))))
+
+ // (cint = 1 OR ABS(cint) = 10) OR (cint > 0 OR cint < -10)
+ testTranslateFilter(Or(
+ Or(
+ EqualTo(attrInt, 1),
+ // Functions such as 'Abs' are not supported
+ EqualTo(Abs(attrInt), 10)
+ ),
+ Or(
+ GreaterThan(attrInt, 0),
+ LessThan(attrInt, -10))), None)
+
+ // In end-to-end testing, conjunctive predicates should have been split
+ // before reaching DataSourceStrategy.translateFilter.
+ // This test exercises each [[case]] directly for unit-testing purposes.
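+ // For example, a predicate like (cint > 1 AND cint < 10) would normally arrive here
+ // as two separate filters, Seq(cint > 1, cint < 10), one per conjunct.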
+ // (cint > 1 AND cint < 10) AND (cint = 6 AND cint IS NOT NULL)
+ testTranslateFilter(And(
+ And(
+ GreaterThan(attrInt, 1),
+ LessThan(attrInt, 10)
+ ),
+ And(
+ EqualTo(attrInt, 6),
+ IsNotNull(attrInt))),
+ Some(sources.And(
+ sources.And(
+ sources.GreaterThan("cint", 1),
+ sources.LessThan("cint", 10)),
+ sources.And(
+ sources.EqualTo("cint", 6),
+ sources.IsNotNull("cint")))))
+
+ // (cint > 1 AND cint < 10) AND (ABS(cint) = 6 AND cint IS NOT NULL)
+ testTranslateFilter(And(
+ And(
+ GreaterThan(attrInt, 1),
+ LessThan(attrInt, 10)
+ ),
+ And(
+ // Functions such as 'Abs' are not supported
+ EqualTo(Abs(attrInt), 6),
+ IsNotNull(attrInt))), None)
+
+ // (cint > 1 OR cint < 10) AND (cint = 6 OR cint IS NOT NULL)
+ testTranslateFilter(And(
+ Or(
+ GreaterThan(attrInt, 1),
+ LessThan(attrInt, 10)
+ ),
+ Or(
+ EqualTo(attrInt, 6),
+ IsNotNull(attrInt))),
+ Some(sources.And(
+ sources.Or(
+ sources.GreaterThan("cint", 1),
+ sources.LessThan("cint", 10)),
+ sources.Or(
+ sources.EqualTo("cint", 6),
+ sources.IsNotNull("cint")))))
+
+ // (cint > 1 OR cint < 10) AND (ABS(cint) = 6 OR cint IS NOT NULL)
+ testTranslateFilter(And(
+ Or(
+ GreaterThan(attrInt, 1),
+ LessThan(attrInt, 10)
+ ),
+ Or(
+ // Functions such as 'Abs' are not supported
+ EqualTo(Abs(attrInt), 6),
+ IsNotNull(attrInt))), None)
+ }
+
+ /**
+ * Translate the given Catalyst [[Expression]] into data source [[sources.Filter]]
+ * then verify against the given [[sources.Filter]].
+ */
+ def testTranslateFilter(catalystFilter: Expression, result: Option[sources.Filter]): Unit = {
+ assertResult(result) {
+ DataSourceStrategy.translateFilter(catalystFilter)
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala
index a9511cbd9e4cf..b4616826e40b3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala
@@ -27,6 +27,7 @@ import org.apache.hadoop.fs.{FileStatus, Path, RawLocalFileSystem}
import org.apache.spark.metrics.source.HiveCatalogMetrics
import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.functions.col
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.{KnownSizeEstimation, SizeEstimator}
@@ -236,6 +237,17 @@ class FileIndexSuite extends SharedSQLContext {
val fileStatusCache = FileStatusCache.getOrCreate(spark)
fileStatusCache.putLeafFiles(new Path("/tmp", "abc"), files.toArray)
}
+
+ test("SPARK-20367 - properly unescape column names in inferPartitioning") {
+ withTempPath { path =>
+ val colToUnescape = "Column/#%'?"
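+ // The special characters in the column name get escaped in the partition directory
+ // names; inferPartitioning should unescape them so the schema read back still
+ // contains the original column name.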
+ spark
+ .range(1)
+ .select(col("id").as(colToUnescape), col("id"))
+ .write.partitionBy(colToUnescape).parquet(path.getAbsolutePath)
+ assert(spark.read.parquet(path.getAbsolutePath).schema.exists(_.name == colToUnescape))
+ }
+ }
}
class FakeParentPathFileSystem extends RawLocalFileSystem {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
index f36162858bf7a..fa3c69612704d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
@@ -42,7 +42,7 @@ import org.apache.spark.util.Utils
class FileSourceStrategySuite extends QueryTest with SharedSQLContext with PredicateHelper {
import testImplicits._
- protected override val sparkConf = new SparkConf().set("spark.default.parallelism", "1")
+ protected override def sparkConf = super.sparkConf.set("spark.default.parallelism", "1")
test("unpartitioned table, single partition") {
val table =
@@ -395,7 +395,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
val fileCatalog = new InMemoryFileIndex(
sparkSession = spark,
- rootPaths = Seq(new Path(tempDir)),
+ rootPathsSpecified = Seq(new Path(tempDir)),
parameters = Map.empty[String, String],
partitionSchema = None)
// This should not fail.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala
new file mode 100644
index 0000000000000..cf340d0ab4a76
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.SaveMode
+import org.apache.spark.sql.test.SharedSQLContext
+
+class SaveIntoDataSourceCommandSuite extends SharedSQLContext {
+
+ test("simpleString is redacted") {
+ val URL = "connection.url"
+ val PASS = "123"
+ val DRIVER = "mydriver"
+
+ val simpleString = SaveIntoDataSourceCommand(
+ spark.range(1).logicalPlan,
+ "jdbc",
+ Nil,
+ Map("password" -> PASS, "url" -> URL, "driver" -> DRIVER),
+ SaveMode.ErrorIfExists).treeString(true)
+
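+ // Sensitive options such as the password and url are expected to be redacted from the
+ // plan's string representation, while non-sensitive options like the driver class remain.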
+ assert(!simpleString.contains(URL))
+ assert(!simpleString.contains(PASS))
+ assert(simpleString.contains(DRIVER))
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 352dba79a4c08..89d9b69dec7ef 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -261,10 +261,10 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
}
test("test for DROPMALFORMED parsing mode") {
- Seq(false, true).foreach { wholeFile =>
+ Seq(false, true).foreach { multiLine =>
val cars = spark.read
.format("csv")
- .option("wholeFile", wholeFile)
+ .option("multiLine", multiLine)
.options(Map("header" -> "true", "mode" -> "dropmalformed"))
.load(testFile(carsFile))
@@ -284,11 +284,11 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
}
test("test for FAILFAST parsing mode") {
- Seq(false, true).foreach { wholeFile =>
+ Seq(false, true).foreach { multiLine =>
val exception = intercept[SparkException] {
spark.read
.format("csv")
- .option("wholeFile", wholeFile)
+ .option("multiLine", multiLine)
.options(Map("header" -> "true", "mode" -> "failfast"))
.load(testFile(carsFile)).collect()
}
@@ -990,13 +990,13 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
}
test("SPARK-18699 put malformed records in a `columnNameOfCorruptRecord` field") {
- Seq(false, true).foreach { wholeFile =>
+ Seq(false, true).foreach { multiLine =>
val schema = new StructType().add("a", IntegerType).add("b", TimestampType)
// We use `PERMISSIVE` mode by default if invalid string is given.
val df1 = spark
.read
.option("mode", "abcd")
- .option("wholeFile", wholeFile)
+ .option("multiLine", multiLine)
.schema(schema)
.csv(testFile(valueMalformedFile))
checkAnswer(df1,
@@ -1011,7 +1011,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
.read
.option("mode", "Permissive")
.option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
- .option("wholeFile", wholeFile)
+ .option("multiLine", multiLine)
.schema(schemaWithCorrField1)
.csv(testFile(valueMalformedFile))
checkAnswer(df2,
@@ -1028,7 +1028,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
.read
.option("mode", "permissive")
.option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
- .option("wholeFile", wholeFile)
+ .option("multiLine", multiLine)
.schema(schemaWithCorrField2)
.csv(testFile(valueMalformedFile))
checkAnswer(df3,
@@ -1041,7 +1041,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
.read
.option("mode", "PERMISSIVE")
.option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
- .option("wholeFile", wholeFile)
+ .option("multiLine", multiLine)
.schema(schema.add(columnNameOfCorruptRecord, IntegerType))
.csv(testFile(valueMalformedFile))
.collect
@@ -1073,7 +1073,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
val df = spark.read
.option("header", true)
- .option("wholeFile", true)
+ .option("multiLine", true)
.csv(path.getAbsolutePath)
// Check if headers have new lines in the names.
@@ -1096,10 +1096,10 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
}
test("Empty file produces empty dataframe with empty schema") {
- Seq(false, true).foreach { wholeFile =>
+ Seq(false, true).foreach { multiLine =>
val df = spark.read.format("csv")
.option("header", true)
- .option("wholeFile", wholeFile)
+ .option("multiLine", multiLine)
.load(testFile(emptyFile))
assert(df.schema === spark.emptyDataFrame.schema)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 2ab03819964be..f8eb5c569f9ea 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.json
import java.io.{File, StringWriter}
import java.nio.charset.StandardCharsets
import java.sql.{Date, Timestamp}
+import java.util.Locale
import com.fasterxml.jackson.core.JsonFactory
import org.apache.hadoop.fs.{Path, PathFilter}
@@ -1803,7 +1804,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
assert(new File(path).listFiles().exists(_.getName.endsWith(".gz")))
- val jsonDF = spark.read.option("wholeFile", true).json(path)
+ val jsonDF = spark.read.option("multiLine", true).json(path)
val jsonDir = new File(dir, "json").getCanonicalPath
jsonDF.coalesce(1).write
.option("compression", "gZiP")
@@ -1825,7 +1826,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
.write
.text(path)
- val jsonDF = spark.read.option("wholeFile", true).json(path)
+ val jsonDF = spark.read.option("multiLine", true).json(path)
val jsonDir = new File(dir, "json").getCanonicalPath
jsonDF.coalesce(1).write.json(jsonDir)
@@ -1854,7 +1855,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
.write
.text(path)
- val jsonDF = spark.read.option("wholeFile", true).json(path)
+ val jsonDF = spark.read.option("multiLine", true).json(path)
// no corrupt record column should be created
assert(jsonDF.schema === StructType(Seq()))
// only the first object should be read
@@ -1875,7 +1876,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
.write
.text(path)
- val jsonDF = spark.read.option("wholeFile", true).option("mode", "PERMISSIVE").json(path)
+ val jsonDF = spark.read.option("multiLine", true).option("mode", "PERMISSIVE").json(path)
assert(jsonDF.count() === corruptRecordCount)
assert(jsonDF.schema === new StructType()
.add("_corrupt_record", StringType)
@@ -1906,7 +1907,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
.write
.text(path)
- val jsonDF = spark.read.option("wholeFile", true).option("mode", "DROPMALFORMED").json(path)
+ val jsonDF = spark.read.option("multiLine", true).option("mode", "DROPMALFORMED").json(path)
checkAnswer(jsonDF, Seq(Row("test")))
}
}
@@ -1929,7 +1930,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
// `FAILFAST` mode should throw an exception for corrupt records.
val exceptionOne = intercept[SparkException] {
spark.read
- .option("wholeFile", true)
+ .option("multiLine", true)
.option("mode", "FAILFAST")
.json(path)
}
@@ -1937,7 +1938,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
val exceptionTwo = intercept[SparkException] {
spark.read
- .option("wholeFile", true)
+ .option("multiLine", true)
.option("mode", "FAILFAST")
.schema(schema)
.json(path)
@@ -1978,4 +1979,43 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
assert(errMsg.startsWith("The field for corrupt records must be string type and nullable"))
}
}
+
+ test("SPARK-18772: Parse special floats correctly") {
+ val jsons = Seq(
+ """{"a": "NaN"}""",
+ """{"a": "Infinity"}""",
+ """{"a": "-Infinity"}""")
+
+ // positive cases
+ val checks: Seq[Double => Boolean] = Seq(
+ _.isNaN,
+ _.isPosInfinity,
+ _.isNegInfinity)
+
+ Seq(FloatType, DoubleType).foreach { dt =>
+ jsons.zip(checks).foreach { case (json, check) =>
+ val ds = spark.read
+ .schema(StructType(Seq(StructField("a", dt))))
+ .json(Seq(json).toDS())
+ .select($"a".cast(DoubleType)).as[Double]
+ assert(check(ds.first()))
+ }
+ }
+
+ // negative cases
+ Seq(FloatType, DoubleType).foreach { dt =>
+ val lowerCasedJsons = jsons.map(_.toLowerCase(Locale.ROOT))
+ // The special floats are case-sensitive, so the lower-cased variants below throw exceptions.
+ lowerCasedJsons.foreach { lowerCasedJson =>
+ val e = intercept[SparkException] {
+ spark.read
+ .option("mode", "FAILFAST")
+ .schema(StructType(Seq(StructField("a", dt))))
+ .json(Seq(lowerCasedJson).toDS())
+ .collect()
+ }
+ assert(e.getMessage.contains("Cannot parse"))
+ }
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala
new file mode 100644
index 0000000000000..caa4f6d70c6a9
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.parquet
+
+import java.io.FileNotFoundException
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext}
+import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter
+import org.apache.parquet.hadoop.{ParquetOutputCommitter, ParquetOutputFormat}
+
+import org.apache.spark.{LocalSparkContext, SparkFunSuite}
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SQLTestUtils
+
+/**
+ * Test logic related to choice of output committers.
+ */
+class ParquetCommitterSuite extends SparkFunSuite with SQLTestUtils
+ with LocalSparkContext {
+
+ private val PARQUET_COMMITTER = classOf[ParquetOutputCommitter].getCanonicalName
+
+ protected var spark: SparkSession = _
+
+ /**
+ * Create a new [[SparkSession]] running in local-cluster mode.
+ */
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ spark = SparkSession.builder()
+ .master("local-cluster[2,1,1024]")
+ .appName("testing")
+ .getOrCreate()
+ }
+
+ override def afterAll(): Unit = {
+ try {
+ if (spark != null) {
+ spark.stop()
+ spark = null
+ }
+ } finally {
+ super.afterAll()
+ }
+ }
+
+ test("alternative output committer, merge schema") {
+ writeDataFrame(MarkingFileOutput.COMMITTER, summary = true, check = true)
+ }
+
+ test("alternative output committer, no merge schema") {
+ writeDataFrame(MarkingFileOutput.COMMITTER, summary = false, check = true)
+ }
+
+ test("Parquet output committer, merge schema") {
+ writeDataFrame(PARQUET_COMMITTER, summary = true, check = false)
+ }
+
+ test("Parquet output committer, no merge schema") {
+ writeDataFrame(PARQUET_COMMITTER, summary = false, check = false)
+ }
+
+ /**
+ * Write a trivial dataframe as Parquet, using the given committer
+ * and job summary option.
+ * @param committer committer to use
+ * @param summary create a job summary
+ * @param check look for a marker file
+ * @return if a marker file was sought, its file status.
+ */
+ private def writeDataFrame(
+ committer: String,
+ summary: Boolean,
+ check: Boolean): Option[FileStatus] = {
+ var result: Option[FileStatus] = None
+ withSQLConf(
+ SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key -> committer,
+ ParquetOutputFormat.ENABLE_JOB_SUMMARY -> summary.toString) {
+ withTempPath { dest =>
+ val df = spark.createDataFrame(Seq((1, "4"), (2, "2")))
+ val destPath = new Path(dest.toURI)
+ df.write.format("parquet").save(destPath.toString)
+ if (check) {
+ result = Some(MarkingFileOutput.checkMarker(
+ destPath,
+ spark.sparkContext.hadoopConfiguration))
+ }
+ }
+ }
+ result
+ }
+}
+
+/**
+ * A file output committer which explicitly touches a file "marker"; this
+ * is how tests can verify that this committer was used.
+ * @param outputPath output path
+ * @param context task context
+ */
+private class MarkingFileOutputCommitter(
+ outputPath: Path,
+ context: TaskAttemptContext) extends FileOutputCommitter(outputPath, context) {
+
+ override def commitJob(context: JobContext): Unit = {
+ super.commitJob(context)
+ MarkingFileOutput.touch(outputPath, context.getConfiguration)
+ }
+}
+
+private object MarkingFileOutput {
+
+ val COMMITTER = classOf[MarkingFileOutputCommitter].getCanonicalName
+
+ /**
+ * Touch the marker.
+ * @param outputPath destination directory
+ * @param conf configuration to create the FS with
+ */
+ def touch(outputPath: Path, conf: Configuration): Unit = {
+ outputPath.getFileSystem(conf).create(new Path(outputPath, "marker")).close()
+ }
+
+ /**
+ * Get the file status of the marker
+ *
+ * @param outputPath destination directory
+ * @param conf configuration to create the FS with
+ * @return the status of the marker
+ * @throws FileNotFoundException if the marker is absent
+ */
+ def checkMarker(outputPath: Path, conf: Configuration): FileStatus = {
+ outputPath.getFileSystem(conf).getFileStatus(new Path(outputPath, "marker"))
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
index 9a3328fcecee8..98427cfe3031c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
-import org.apache.spark.util.{AccumulatorContext, LongAccumulator}
+import org.apache.spark.util.{AccumulatorContext, AccumulatorV2}
/**
* A test suite that tests Parquet filter2 API based filter pushdown optimization.
@@ -499,18 +499,20 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
val path = s"${dir.getCanonicalPath}/table"
(1 to 1024).map(i => (101, i)).toDF("a", "b").write.parquet(path)
- Seq(("true", (x: Long) => x == 0), ("false", (x: Long) => x > 0)).map { case (push, func) =>
- withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> push) {
- val accu = new LongAccumulator
- accu.register(sparkContext, Some("numRowGroups"))
+ Seq(true, false).foreach { enablePushDown =>
+ withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> enablePushDown.toString) {
+ val accu = new NumRowGroupsAcc
+ sparkContext.register(accu)
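+ // Every row in the file has a = 101, so with filter pushdown enabled the predicate
+ // a < 100 is expected to skip every row group (the accumulator stays 0); with pushdown
+ // disabled the accumulator should be positive.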
val df = spark.read.parquet(path).filter("a < 100")
df.foreachPartition(_.foreach(v => accu.add(0)))
df.collect
- val numRowGroups = AccumulatorContext.lookForAccumulatorByName("numRowGroups")
- assert(numRowGroups.isDefined)
- assert(func(numRowGroups.get.asInstanceOf[LongAccumulator].value))
+ if (enablePushDown) {
+ assert(accu.value == 0)
+ } else {
+ assert(accu.value > 0)
+ }
AccumulatorContext.remove(accu.id)
}
}
@@ -536,4 +538,43 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
// scalastyle:on nonascii
}
}
+
+ test("SPARK-20364: Disable Parquet predicate pushdown for fields having dots in the names") {
+ import testImplicits._
+
+ Seq(true, false).foreach { vectorized =>
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized.toString,
+ SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> true.toString) {
+ withTempPath { path =>
+ Seq(Some(1), None).toDF("col.dots").write.parquet(path.getAbsolutePath)
+ val readBack = spark.read.parquet(path.getAbsolutePath).where("`col.dots` IS NOT NULL")
+ assert(readBack.count() == 1)
+ }
+ }
+ }
+ }
+}
+
+class NumRowGroupsAcc extends AccumulatorV2[Integer, Integer] {
+ private var _sum = 0
+
+ override def isZero: Boolean = _sum == 0
+
+ override def copy(): AccumulatorV2[Integer, Integer] = {
+ val acc = new NumRowGroupsAcc()
+ acc._sum = _sum
+ acc
+ }
+
+ override def reset(): Unit = _sum = 0
+
+ override def add(v: Integer): Unit = _sum += v
+
+ override def merge(other: AccumulatorV2[Integer, Integer]): Unit = other match {
+ case a: NumRowGroupsAcc => _sum += a._sum
+ case _ => throw new UnsupportedOperationException(
+ s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
+ }
+
+ override def value: Integer = _sum
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
index b4f3de9961209..7225693e50279 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
@@ -1022,4 +1022,16 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
}
}
}
+
+ test("SPARK-22109: Resolve type conflicts between strings and timestamps in partition column") {
+ val df = Seq(
+ (1, "2015-01-01 00:00:00"),
+ (2, "2014-01-01 00:00:00"),
+ (3, "blah")).toDF("i", "str")
+
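+ // Two of the partition values parse as timestamps while the third ("blah") does not,
+ // so partition type inference is expected to fall back to string and the round trip
+ // should preserve all three values.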
+ withTempPath { path =>
+ df.write.format("parquet").partitionBy("str").save(path.getAbsolutePath)
+ checkAnswer(spark.read.load(path.getAbsolutePath), df)
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
index c36609586c807..2efff3f57d7d3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
@@ -23,7 +23,7 @@ import java.sql.Timestamp
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.parquet.hadoop.ParquetOutputFormat
-import org.apache.spark.SparkException
+import org.apache.spark.{DebugFilesystem, SparkException}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
@@ -316,6 +316,39 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
}
}
+ /**
+ * This is part of the test 'Enabling/disabling ignoreCorruptFiles', but run in a loop
+ * to increase the chance of failure.
+ */
+ ignore("SPARK-20407 ParquetQuerySuite 'Enabling/disabling ignoreCorruptFiles' flaky test") {
+ def testIgnoreCorruptFiles(): Unit = {
+ withTempDir { dir =>
+ val basePath = dir.getCanonicalPath
+ spark.range(1).toDF("a").write.parquet(new Path(basePath, "first").toString)
+ spark.range(1, 2).toDF("a").write.parquet(new Path(basePath, "second").toString)
+ spark.range(2, 3).toDF("a").write.json(new Path(basePath, "third").toString)
+ val df = spark.read.parquet(
+ new Path(basePath, "first").toString,
+ new Path(basePath, "second").toString,
+ new Path(basePath, "third").toString)
+ checkAnswer(
+ df,
+ Seq(Row(0), Row(1)))
+ }
+ }
+
+ for (i <- 1 to 100) {
+ DebugFilesystem.clearOpenStreams()
+ withSQLConf(SQLConf.IGNORE_CORRUPT_FILES.key -> "false") {
+ val exception = intercept[SparkException] {
+ testIgnoreCorruptFiles()
+ }
+ assert(exception.getMessage().contains("is not a Parquet file"))
+ }
+ DebugFilesystem.assertNoOpenStreams()
+ }
+ }
+
test("SPARK-8990 DataFrameReader.parquet() should respect user specified options") {
withTempPath { dir =>
val basePath = dir.getCanonicalPath
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index 26c45e092dc65..afb8ced53e25c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -157,7 +157,7 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
}
test("broadcast hint in SQL") {
- import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, Join}
+ import org.apache.spark.sql.catalyst.plans.logical.{ResolvedHint, Join}
spark.range(10).createOrReplaceTempView("t")
spark.range(10).createOrReplaceTempView("u")
@@ -170,12 +170,12 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
val plan3 = sql(s"SELECT /*+ $name(v) */ * FROM t JOIN u ON t.id = u.id").queryExecution
.optimizedPlan
- assert(plan1.asInstanceOf[Join].left.isInstanceOf[BroadcastHint])
- assert(!plan1.asInstanceOf[Join].right.isInstanceOf[BroadcastHint])
- assert(!plan2.asInstanceOf[Join].left.isInstanceOf[BroadcastHint])
- assert(plan2.asInstanceOf[Join].right.isInstanceOf[BroadcastHint])
- assert(!plan3.asInstanceOf[Join].left.isInstanceOf[BroadcastHint])
- assert(!plan3.asInstanceOf[Join].right.isInstanceOf[BroadcastHint])
+ assert(plan1.asInstanceOf[Join].left.isInstanceOf[ResolvedHint])
+ assert(!plan1.asInstanceOf[Join].right.isInstanceOf[ResolvedHint])
+ assert(!plan2.asInstanceOf[Join].left.isInstanceOf[ResolvedHint])
+ assert(plan2.asInstanceOf[Join].right.isInstanceOf[ResolvedHint])
+ assert(!plan3.asInstanceOf[Join].left.isInstanceOf[ResolvedHint])
+ assert(!plan3.asInstanceOf[Join].right.isInstanceOf[ResolvedHint])
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index ede63fea9606f..9c9e9dc07e91b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -27,7 +27,7 @@ import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, StructType}
+import org.apache.spark.sql.types._
import org.apache.spark.unsafe.map.BytesToBytesMap
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.collection.CompactBuffer
@@ -253,6 +253,59 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
map.free()
}
+ test("SPARK-24257: insert big values into LongToUnsafeRowMap") {
+ val taskMemoryManager = new TaskMemoryManager(
+ new StaticMemoryManager(
+ new SparkConf().set("spark.memory.offHeap.enabled", "false"),
+ Long.MaxValue,
+ Long.MaxValue,
+ 1),
+ 0)
+ val unsafeProj = UnsafeProjection.create(Array[DataType](StringType))
+ val map = new LongToUnsafeRowMap(taskMemoryManager, 1)
+
+ val key = 0L
+ // the page array is initialized with length 1 << 17 (1M bytes),
+ // so here we need a value larger than 1 << 18 (2M bytes), to trigger the bug
+ val bigStr = UTF8String.fromString("x" * (1 << 19))
+
+ map.append(key, unsafeProj(InternalRow(bigStr)))
+ map.optimize()
+
+ val resultRow = new UnsafeRow(1)
+ assert(map.getValue(key, resultRow).getUTF8String(0) === bigStr)
+ map.free()
+ }
+
+ test("SPARK-24809: Serializing LongToUnsafeRowMap in executor may result in data error") {
+ val unsafeProj = UnsafeProjection.create(Array[DataType](LongType))
+ val originalMap = new LongToUnsafeRowMap(mm, 1)
+
+ val key1 = 1L
+ val value1 = 4852306286022334418L
+
+ val key2 = 2L
+ val value2 = 8813607448788216010L
+
+ originalMap.append(key1, unsafeProj(InternalRow(value1)))
+ originalMap.append(key2, unsafeProj(InternalRow(value2)))
+ originalMap.optimize()
+
+ val ser = sparkContext.env.serializer.newInstance()
+ // Simulate serialize/deserialize twice on driver and executor
+ val firstTimeSerialized = ser.deserialize[LongToUnsafeRowMap](ser.serialize(originalMap))
+ val secondTimeSerialized =
+ ser.deserialize[LongToUnsafeRowMap](ser.serialize(firstTimeSerialized))
+
+ val resultRow = new UnsafeRow(1)
+ assert(secondTimeSerialized.getValue(key1, resultRow).getLong(0) === value1)
+ assert(secondTimeSerialized.getValue(key2, resultRow).getLong(0) === value2)
+
+ originalMap.free()
+ firstTimeSerialized.free()
+ secondTimeSerialized.free()
+ }
+
test("Spark-14521") {
val ser = new KryoSerializer(
(new SparkConf).set("spark.kryo.referenceTracking", "false")).newInstance()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 2ce7db6a22c01..79d1fbfa3f072 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -143,6 +143,24 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
)
}
+ test("ObjectHashAggregate metrics") {
+ // Assume the execution plan is
+ // ... -> ObjectHashAggregate(nodeId = 2) -> Exchange(nodeId = 1)
+ // -> ObjectHashAggregate(nodeId = 0)
+ val df = testData2.groupBy().agg(collect_set('a)) // 2 partitions
+ testSparkPlanMetrics(df, 1, Map(
+ 2L -> ("ObjectHashAggregate", Map("number of output rows" -> 2L)),
+ 0L -> ("ObjectHashAggregate", Map("number of output rows" -> 1L)))
+ )
+
+ // 2 partitions and each partition contains 2 keys
+ val df2 = testData2.groupBy('a).agg(collect_set('a))
+ testSparkPlanMetrics(df2, 1, Map(
+ 2L -> ("ObjectHashAggregate", Map("number of output rows" -> 4L)),
+ 0L -> ("ObjectHashAggregate", Map("number of output rows" -> 3L)))
+ )
+ }
+
test("Sort metrics") {
// Assume the execution plan is
// WholeStageCodegen(nodeId = 0, Range(nodeId = 2) -> Sort(nodeId = 1))
@@ -270,6 +288,18 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
}
}
+ test("SortMergeJoin(left-anti) metrics") {
+ val anti = testData2.filter("a > 2")
+ withTempView("antiData") {
+ anti.createOrReplaceTempView("antiData")
+ val df = spark.sql(
+ "SELECT * FROM testData2 ANTI JOIN antiData ON testData2.a = antiData.a")
+ testSparkPlanMetrics(df, 1, Map(
+ 0L -> ("SortMergeJoin", Map("number of output rows" -> 4L)))
+ )
+ }
+ }
+
test("save metrics") {
withTempPath { file =>
val previousExecutionIds = spark.sharedState.listener.executionIdToData.keySet
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala
index 20ac06f048c6f..3d480b148db55 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala
@@ -28,8 +28,8 @@ import org.apache.spark.sql.test.SharedSQLContext
class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext {
/** To avoid caching of FS objects */
- override protected val sparkConf =
- new SparkConf().set(s"spark.hadoop.fs.$scheme.impl.disable.cache", "true")
+ override protected def sparkConf =
+ super.sparkConf.set(s"spark.hadoop.fs.$scheme.impl.disable.cache", "true")
import CompactibleFileStreamLog._
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala
index 662c4466b21b2..48e70e48b1799 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala
@@ -38,8 +38,8 @@ import org.apache.spark.util.UninterruptibleThread
class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext {
/** To avoid caching of FS objects */
- override protected val sparkConf =
- new SparkConf().set(s"spark.hadoop.fs.$scheme.impl.disable.cache", "true")
+ override protected def sparkConf =
+ super.sparkConf.set(s"spark.hadoop.fs.$scheme.impl.disable.cache", "true")
private implicit def toOption[A](a: A): Option[A] = Option(a)
@@ -259,6 +259,23 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext {
fm.rename(path2, path3)
}
}
+
+ test("verifyBatchIds") {
+ import HDFSMetadataLog.verifyBatchIds
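+ // verifyBatchIds should accept a contiguous range of batch ids that covers the optional
+ // start and end ids, and throw IllegalStateException when there are gaps or the
+ // boundaries are not covered.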
+ verifyBatchIds(Seq(1L, 2L, 3L), Some(1L), Some(3L))
+ verifyBatchIds(Seq(1L), Some(1L), Some(1L))
+ verifyBatchIds(Seq(1L, 2L, 3L), None, Some(3L))
+ verifyBatchIds(Seq(1L, 2L, 3L), Some(1L), None)
+ verifyBatchIds(Seq(1L, 2L, 3L), None, None)
+
+ intercept[IllegalStateException](verifyBatchIds(Seq(), Some(1L), None))
+ intercept[IllegalStateException](verifyBatchIds(Seq(), None, Some(1L)))
+ intercept[IllegalStateException](verifyBatchIds(Seq(), Some(1L), Some(1L)))
+ intercept[IllegalStateException](verifyBatchIds(Seq(2, 3, 4), Some(1L), None))
+ intercept[IllegalStateException](verifyBatchIds(Seq(2, 3, 4), None, Some(5L)))
+ intercept[IllegalStateException](verifyBatchIds(Seq(2, 3, 4), Some(1L), Some(5L)))
+ intercept[IllegalStateException](verifyBatchIds(Seq(1, 2, 4, 5), Some(1L), Some(5L)))
+ }
}
/** FakeFileSystem to test fallback of the HDFSMetadataLog from FileContext to FileSystem API */
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala
new file mode 100644
index 0000000000000..bdba536425a43
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent.TimeUnit
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest}
+import org.apache.spark.util.ManualClock
+
+class RateSourceSuite extends StreamTest {
+
+ import testImplicits._
+
+ case class AdvanceRateManualClock(seconds: Long) extends AddData {
+ override def addData(query: Option[StreamExecution]): (Source, Offset) = {
+ assert(query.nonEmpty)
+ val rateSource = query.get.logicalPlan.collect {
+ case StreamingExecutionRelation(source, _) if source.isInstanceOf[RateStreamSource] =>
+ source.asInstanceOf[RateStreamSource]
+ }.head
+ rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds))
+ (rateSource, rateSource.getOffset.get)
+ }
+ }
+
+ test("basic") {
+ val input = spark.readStream
+ .format("rate")
+ .option("rowsPerSecond", "10")
+ .option("useManualClock", "true")
+ .load()
+ testStream(input)(
+ AdvanceRateManualClock(seconds = 1),
+ CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*),
+ StopStream,
+ StartStream(),
+ // Advance 2 seconds because creating a new RateSource will also create a new ManualClock
+ AdvanceRateManualClock(seconds = 2),
+ CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*)
+ )
+ }
+
+ test("uniform distribution of event timestamps") {
+ val input = spark.readStream
+ .format("rate")
+ .option("rowsPerSecond", "1500")
+ .option("useManualClock", "true")
+ .load()
+ .as[(java.sql.Timestamp, Long)]
+ .map(v => (v._1.getTime, v._2))
+ val expectedAnswer = (0 until 1500).map { v =>
+ (math.round(v * (1000.0 / 1500)), v)
+ }
+ testStream(input)(
+ AdvanceRateManualClock(seconds = 1),
+ CheckLastBatch(expectedAnswer: _*)
+ )
+ }
+
+ test("valueAtSecond") {
+ import RateStreamSource._
+
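+ // During rampUpTimeSeconds the per-second rate is expected to grow roughly linearly
+ // towards rowsPerSecond, so valueAtSecond returns the cumulative row count, e.g. for
+ // rowsPerSecond = 5 and rampUpTimeSeconds = 2: 0, 1, 3, then +5 per second afterwards.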
+ assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 0)
+ assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 5)
+
+ assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 0)
+ assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 1)
+ assert(valueAtSecond(seconds = 2, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 3)
+ assert(valueAtSecond(seconds = 3, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 8)
+
+ assert(valueAtSecond(seconds = 0, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 0)
+ assert(valueAtSecond(seconds = 1, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 2)
+ assert(valueAtSecond(seconds = 2, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 6)
+ assert(valueAtSecond(seconds = 3, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 12)
+ assert(valueAtSecond(seconds = 4, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 20)
+ assert(valueAtSecond(seconds = 5, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 30)
+ }
+
+ test("rampUpTime") {
+ val input = spark.readStream
+ .format("rate")
+ .option("rowsPerSecond", "10")
+ .option("rampUpTime", "4s")
+ .option("useManualClock", "true")
+ .load()
+ .as[(java.sql.Timestamp, Long)]
+ .map(v => (v._1.getTime, v._2))
+ testStream(input)(
+ AdvanceRateManualClock(seconds = 1),
+ CheckLastBatch((0 until 2).map(v => v * 500 -> v): _*), // speed = 2
+ AdvanceRateManualClock(seconds = 1),
+ CheckLastBatch((2 until 6).map(v => 1000 + (v - 2) * 250 -> v): _*), // speed = 4
+ AdvanceRateManualClock(seconds = 1),
+ CheckLastBatch({
+ Seq(2000 -> 6, 2167 -> 7, 2333 -> 8, 2500 -> 9, 2667 -> 10, 2833 -> 11)
+ }: _*), // speed = 6
+ AdvanceRateManualClock(seconds = 1),
+ CheckLastBatch((12 until 20).map(v => 3000 + (v - 12) * 125 -> v): _*), // speed = 8
+ AdvanceRateManualClock(seconds = 1),
+ // Now we should reach full speed
+ CheckLastBatch((20 until 30).map(v => 4000 + (v - 20) * 100 -> v): _*), // speed = 10
+ AdvanceRateManualClock(seconds = 1),
+ CheckLastBatch((30 until 40).map(v => 5000 + (v - 30) * 100 -> v): _*), // speed = 10
+ AdvanceRateManualClock(seconds = 1),
+ CheckLastBatch((40 until 50).map(v => 6000 + (v - 40) * 100 -> v): _*) // speed = 10
+ )
+ }
+
+ test("numPartitions") {
+ val input = spark.readStream
+ .format("rate")
+ .option("rowsPerSecond", "10")
+ .option("numPartitions", "6")
+ .option("useManualClock", "true")
+ .load()
+ .select(spark_partition_id())
+ .distinct()
+ testStream(input)(
+ AdvanceRateManualClock(1),
+ CheckLastBatch((0 until 6): _*)
+ )
+ }
+
+ testQuietly("overflow") {
+ val input = spark.readStream
+ .format("rate")
+ .option("rowsPerSecond", Long.MaxValue.toString)
+ .option("useManualClock", "true")
+ .load()
+ .select(spark_partition_id())
+ .distinct()
+ testStream(input)(
+ AdvanceRateManualClock(2),
+ ExpectFailure[ArithmeticException](t => {
+ Seq("overflow", "rowsPerSecond").foreach { msg =>
+ assert(t.getMessage.contains(msg))
+ }
+ })
+ )
+ }
+
+ testQuietly("illegal option values") {
+ def testIllegalOptionValue(
+ option: String,
+ value: String,
+ expectedMessages: Seq[String]): Unit = {
+ val e = intercept[StreamingQueryException] {
+ spark.readStream
+ .format("rate")
+ .option(option, value)
+ .load()
+ .writeStream
+ .format("console")
+ .start()
+ .awaitTermination()
+ }
+ assert(e.getCause.isInstanceOf[IllegalArgumentException])
+ for (msg <- expectedMessages) {
+ assert(e.getCause.getMessage.contains(msg))
+ }
+ }
+
+ testIllegalOptionValue("rowsPerSecond", "-1", Seq("-1", "rowsPerSecond", "positive"))
+ testIllegalOptionValue("numPartitions", "-1", Seq("-1", "numPartitions", "positive"))
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index ebb7422765ebb..cc09b2d5b7763 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -314,7 +314,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
test("SPARK-19677: Committing a delta file atop an existing one should not fail on HDFS") {
val conf = new Configuration()
conf.set("fs.fake.impl", classOf[RenameLikeHDFSFileSystem].getName)
- conf.set("fs.default.name", "fake:///")
+ conf.set("fs.defaultFS", "fake:///")
val provider = newStoreProvider(hadoopConf = conf)
provider.getStore(0).commit()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerMemorySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerMemorySuite.scala
new file mode 100644
index 0000000000000..24a09f37c645c
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerMemorySuite.scala
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.ui
+
+import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite}
+import org.apache.spark.LocalSparkContext.withSpark
+import org.apache.spark.internal.config
+import org.apache.spark.sql.{Column, SparkSession}
+import org.apache.spark.sql.catalyst.util.quietly
+import org.apache.spark.sql.functions._
+
+class SQLListenerMemorySuite extends SparkFunSuite {
+
+ test("SPARK-22471 - _stageIdToStageMetrics grows too large on long executions") {
+ quietly {
+ val conf = new SparkConf()
+ .setMaster("local[*]")
+ .setAppName("MemoryLeakTest")
+ /* Don't retry the tasks to run this test quickly */
+ .set(config.MAX_TASK_FAILURES, 1)
+ .set("spark.ui.retainedStages", "50")
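+ // With spark.ui.retainedStages set to 50, the listener is expected to evict old stage
+ // metrics so that stageIdToStageMetrics stays bounded (asserted below).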
+ withSpark(new SparkContext(conf)) { sc =>
+ SparkSession.sqlListener.set(null)
+ val spark = new SparkSession(sc)
+ import spark.implicits._
+
+ val sample = List(
+ (1, 10),
+ (2, 20),
+ (3, 30)
+ ).toDF("id", "value")
+
+ /* Some complex computation with many stages. */
+ val joins = 1 to 100
+ val summedCol: Column = joins
+ .map(j => col(s"value$j"))
+ .reduce(_ + _)
+ val res = joins
+ .map { j =>
+ sample.select($"id", $"value" * j as s"value$j")
+ }
+ .reduce(_.join(_, "id"))
+ .select($"id", summedCol as "value")
+ .groupBy("id")
+ .agg(sum($"value") as "value")
+ .orderBy("id")
+ res.collect()
+
+ sc.listenerBus.waitUntilEmpty(10000)
+ assert(spark.sharedState.listener.stageIdToStageMetrics.size <= 50)
+ }
+ }
+ }
+
+ test("no memory leak") {
+ quietly {
+ val conf = new SparkConf()
+ .setMaster("local")
+ .setAppName("test")
+ .set(config.MAX_TASK_FAILURES, 1) // Don't retry the tasks to run this test quickly
+ .set("spark.sql.ui.retainedExecutions", "50") // Set it to 50 to run this test quickly
+ withSpark(new SparkContext(conf)) { sc =>
+ SparkSession.sqlListener.set(null)
+ val spark = new SparkSession(sc)
+ import spark.implicits._
+ // Run 100 successful executions and 100 failed executions.
+ // Each execution only has one job and one stage.
+ for (i <- 0 until 100) {
+ val df = Seq(
+ (1, 1),
+ (2, 2)
+ ).toDF()
+ df.collect()
+ try {
+ df.foreach(_ => throw new RuntimeException("Oops"))
+ } catch {
+ case e: SparkException => // This is expected for a failed job
+ }
+ }
+ sc.listenerBus.waitUntilEmpty(10000)
+ assert(spark.sharedState.listener.getCompletedExecutions.size <= 50)
+ assert(spark.sharedState.listener.getFailedExecutions.size <= 50)
+ // 50 for successful executions and 50 for failed executions
+ assert(spark.sharedState.listener.executionIdToData.size <= 100)
+ assert(spark.sharedState.listener.jobIdToExecutionId.size <= 100)
+ assert(spark.sharedState.listener.stageIdToStageMetrics.size <= 100)
+ }
+ }
+ }
+
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
index e6cd41e4facf1..23420a5af59a2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
@@ -24,14 +24,12 @@ import org.mockito.Mockito.mock
import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
-import org.apache.spark.internal.config
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler._
-import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.catalyst.util.quietly
import org.apache.spark.sql.execution.{LeafExecNode, QueryExecution, SparkPlanInfo, SQLExecution}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.test.SharedSQLContext
@@ -485,46 +483,3 @@ private case class MyPlan(sc: SparkContext, expectedValue: Long) extends LeafExe
sc.emptyRDD
}
}
-
-
-class SQLListenerMemoryLeakSuite extends SparkFunSuite {
-
- test("no memory leak") {
- quietly {
- val conf = new SparkConf()
- .setMaster("local")
- .setAppName("test")
- .set(config.MAX_TASK_FAILURES, 1) // Don't retry the tasks to run this test quickly
- .set("spark.sql.ui.retainedExecutions", "50") // Set it to 50 to run this test quickly
- val sc = new SparkContext(conf)
- try {
- SparkSession.sqlListener.set(null)
- val spark = new SparkSession(sc)
- import spark.implicits._
- // Run 100 successful executions and 100 failed executions.
- // Each execution only has one job and one stage.
- for (i <- 0 until 100) {
- val df = Seq(
- (1, 1),
- (2, 2)
- ).toDF()
- df.collect()
- try {
- df.foreach(_ => throw new RuntimeException("Oops"))
- } catch {
- case e: SparkException => // This is expected for a failed job
- }
- }
- sc.listenerBus.waitUntilEmpty(10000)
- assert(spark.sharedState.listener.getCompletedExecutions.size <= 50)
- assert(spark.sharedState.listener.getFailedExecutions.size <= 50)
- // 50 for successful executions and 50 for failed executions
- assert(spark.sharedState.listener.executionIdToData.size <= 100)
- assert(spark.sharedState.listener.jobIdToExecutionId.size <= 100)
- assert(spark.sharedState.listener.stageIdToStageMetrics.size <= 100)
- } finally {
- sc.stop()
- }
- }
- }
-}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala
new file mode 100644
index 0000000000000..7e0f747d2cb61
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala
@@ -0,0 +1,221 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.vectorized
+
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.util.ArrayData
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach {
+ private def withVector(
+ vector: ColumnVector)(
+ block: ColumnVector => Unit): Unit = {
+ try block(vector) finally vector.close()
+ }
+
+ private def testVectors(
+ name: String,
+ size: Int,
+ dt: DataType)(
+ block: ColumnVector => Unit): Unit = {
+ test(name) {
+ withVector(new OnHeapColumnVector(size, dt))(block)
+ withVector(new OffHeapColumnVector(size, dt))(block)
+ }
+ }
+
+ testVectors("boolean", 10, BooleanType) { testVector =>
+ (0 until 10).foreach { i =>
+ testVector.appendBoolean(i % 2 == 0)
+ }
+
+ val array = new ColumnVector.Array(testVector)
+
+ (0 until 10).foreach { i =>
+ assert(array.getBoolean(i) === (i % 2 == 0))
+ }
+ }
+
+ testVectors("byte", 10, ByteType) { testVector =>
+ (0 until 10).foreach { i =>
+ testVector.appendByte(i.toByte)
+ }
+
+ val array = new ColumnVector.Array(testVector)
+
+ (0 until 10).foreach { i =>
+ assert(array.getByte(i) === (i.toByte))
+ }
+ }
+
+ testVectors("short", 10, ShortType) { testVector =>
+ (0 until 10).foreach { i =>
+ testVector.appendShort(i.toShort)
+ }
+
+ val array = new ColumnVector.Array(testVector)
+
+ (0 until 10).foreach { i =>
+ assert(array.getShort(i) === (i.toShort))
+ }
+ }
+
+ testVectors("int", 10, IntegerType) { testVector =>
+ (0 until 10).foreach { i =>
+ testVector.appendInt(i)
+ }
+
+ val array = new ColumnVector.Array(testVector)
+
+ (0 until 10).foreach { i =>
+ assert(array.getInt(i) === i)
+ }
+ }
+
+ testVectors("long", 10, LongType) { testVector =>
+ (0 until 10).foreach { i =>
+ testVector.appendLong(i)
+ }
+
+ val array = new ColumnVector.Array(testVector)
+
+ (0 until 10).foreach { i =>
+ assert(array.getLong(i) === i)
+ }
+ }
+
+ testVectors("float", 10, FloatType) { testVector =>
+ (0 until 10).foreach { i =>
+ testVector.appendFloat(i.toFloat)
+ }
+
+ val array = new ColumnVector.Array(testVector)
+
+ (0 until 10).foreach { i =>
+ assert(array.getFloat(i) === i.toFloat)
+ }
+ }
+
+ testVectors("double", 10, DoubleType) { testVector =>
+ (0 until 10).foreach { i =>
+ testVector.appendDouble(i.toDouble)
+ }
+
+ val array = new ColumnVector.Array(testVector)
+
+ (0 until 10).foreach { i =>
+ assert(array.getDouble(i) === i.toDouble)
+ }
+ }
+
+ testVectors("string", 10, StringType) { testVector =>
+ (0 until 10).map { i =>
+ val utf8 = s"str$i".getBytes("utf8")
+ testVector.appendByteArray(utf8, 0, utf8.length)
+ }
+
+ val array = new ColumnVector.Array(testVector)
+
+ (0 until 10).foreach { i =>
+ assert(array.getUTF8String(i) === UTF8String.fromString(s"str$i"))
+ }
+ }
+
+ testVectors("binary", 10, BinaryType) { testVector =>
+ (0 until 10).map { i =>
+ val utf8 = s"str$i".getBytes("utf8")
+ testVector.appendByteArray(utf8, 0, utf8.length)
+ }
+
+ val array = new ColumnVector.Array(testVector)
+
+ (0 until 10).foreach { i =>
+ val utf8 = s"str$i".getBytes("utf8")
+ assert(array.getBinary(i) === utf8)
+ }
+ }
+
+ val arrayType: ArrayType = ArrayType(IntegerType, containsNull = true)
+ testVectors("array", 10, arrayType) { testVector =>
+
+ val data = testVector.arrayData()
+ var i = 0
+ while (i < 6) {
+ data.putInt(i, i)
+ i += 1
+ }
+
+ // Populate it with arrays [0], [1, 2], [], [3, 4, 5]
+ testVector.putArray(0, 0, 1)
+ testVector.putArray(1, 1, 2)
+ testVector.putArray(2, 3, 0)
+ testVector.putArray(3, 3, 3)
+
+ val array = new ColumnVector.Array(testVector)
+
+ assert(array.getArray(0).toIntArray() === Array(0))
+ assert(array.getArray(1).asInstanceOf[ArrayData].toIntArray() === Array(1, 2))
+ assert(array.getArray(2).asInstanceOf[ArrayData].toIntArray() === Array.empty[Int])
+ assert(array.getArray(3).asInstanceOf[ArrayData].toIntArray() === Array(3, 4, 5))
+ }
+
+ val structType: StructType = new StructType().add("int", IntegerType).add("double", DoubleType)
+ testVectors("struct", 10, structType) { testVector =>
+ val c1 = testVector.getChildColumn(0)
+ val c2 = testVector.getChildColumn(1)
+ c1.putInt(0, 123)
+ c2.putDouble(0, 3.45)
+ c1.putInt(1, 456)
+ c2.putDouble(1, 5.67)
+
+ val array = new ColumnVector.Array(testVector)
+
+ assert(array.getStruct(0, 2).asInstanceOf[ColumnarBatch.Row].getInt(0) === 123)
+ assert(array.getStruct(0, 2).asInstanceOf[ColumnarBatch.Row].getDouble(1) === 3.45)
+ assert(array.getStruct(1, 2).asInstanceOf[ColumnarBatch.Row].getInt(0) === 456)
+ assert(array.getStruct(1, 2).asInstanceOf[ColumnarBatch.Row].getDouble(1) === 5.67)
+ }
+
+ test("[SPARK-22092] off-heap column vector reallocation corrupts array data") {
+ withVector(new OffHeapColumnVector(8, arrayType)) { testVector =>
+ val data = testVector.arrayData()
+ (0 until 8).foreach(i => data.putInt(i, i))
+ (0 until 8).foreach(i => testVector.putArray(i, i, 1))
+
+ // Increase vector's capacity and reallocate the data to new bigger buffers.
+ testVector.reserve(16)
+
+ // Check that none of the values got lost/overwritten.
+ val array = new ColumnVector.Array(testVector)
+ (0 until 8).foreach { i =>
+ assert(array.getArray(i).toIntArray() === Array(i))
+ }
+ }
+ }
+
+ test("[SPARK-22092] off-heap column vector reallocation corrupts struct nullability") {
+ withVector(new OffHeapColumnVector(8, structType)) { testVector =>
+ (0 until 8).foreach(i => if (i % 2 == 0) testVector.putNull(i) else testVector.putNotNull(i))
+ testVector.reserve(16)
+ (0 until 8).foreach(i => assert(testVector.isNullAt(i) == (i % 2 == 0)))
+ }
+ }
+}
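
The two SPARK-22092 regression tests above pin down an invariant of OffHeapColumnVector.reserve(): when the vector grows, every backing buffer (array offsets and lengths, struct null flags) has to be carried over, not just the element data. A minimal sketch of that invariant, assuming plain JVM arrays rather than Spark's off-heap buffers; ArrayColumn and its fields are illustrative names, not Spark APIs.

    // Sketch only: growing the capacity must copy every backing buffer.
    final class ArrayColumn(var capacity: Int) {
      private var offsets = new Array[Int](capacity)
      private var lengths = new Array[Int](capacity)
      private var nulls   = new Array[Boolean](capacity)

      def reserve(required: Int): Unit = {
        if (required > capacity) {
          val newCapacity = math.max(required, capacity * 2)
          // The SPARK-22092 bug class: copying only some of these buffers loses data.
          offsets = java.util.Arrays.copyOf(offsets, newCapacity)
          lengths = java.util.Arrays.copyOf(lengths, newCapacity)
          nulls   = java.util.Arrays.copyOf(nulls, newCapacity)
          capacity = newCapacity
        }
      }
    }

If reserve() dropped, say, the lengths buffer, the getArray(i) assertions above would read stale metadata after reallocation, which is exactly what the tests guard against.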
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
index 8184d7d909f4b..31f013983bce0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
@@ -116,6 +116,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
assert(v._1 == Platform.getByte(null, addr + v._2))
}
}
+ column.close()
}}
}
@@ -317,6 +318,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
assert(v._1 == Platform.getLong(null, addr + 8 * v._2))
}
}
+ column.close()
}}
}
@@ -443,6 +445,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
column.reset()
assert(column.arrayData().elementsAppended == 0)
+ column.close()
}}
}
@@ -498,6 +501,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
column.putArray(0, 0, array.length)
assert(ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(0)).asInstanceOf[Array[Int]]
=== array)
+ column.close()
}}
}
@@ -528,6 +532,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
val s2 = column.getStruct(1)
assert(s2.getInt(0) == 456)
assert(s2.getDouble(1) == 5.67)
+ column.close()
}}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
index a283ff971adcd..948f179f5e8f0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
@@ -270,4 +270,15 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
val e2 = intercept[AnalysisException](spark.conf.unset(SCHEMA_STRING_LENGTH_THRESHOLD.key))
assert(e2.message.contains("Cannot modify the value of a static config"))
}
+
+ test("SPARK-21588 SQLContext.getConf(key, null) should return null") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ assert("1" == spark.conf.get(SQLConf.SHUFFLE_PARTITIONS.key, null))
+ assert("1" == spark.conf.get(SQLConf.SHUFFLE_PARTITIONS.key, ""))
+ }
+
+ assert(spark.conf.getOption("spark.sql.nonexistent").isEmpty)
+ assert(null == spark.conf.get("spark.sql.nonexistent", null))
+ assert("" == spark.conf.get("spark.sql.nonexistent", ""))
+ }
}
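
The SPARK-21588 test pins down the contract that spark.conf.get(key, default) returns the supplied default verbatim, including null, when the key is unset. A minimal sketch of that contract over a plain Map, assuming nothing about RuntimeConfig's real implementation; MiniConf is an illustrative stand-in.

    // Sketch: the default is returned as-is (even null) when the key is not set.
    class MiniConf(settings: Map[String, String]) {
      def getOption(key: String): Option[String] = settings.get(key)
      def get(key: String, default: String): String = settings.getOrElse(key, default)
    }

    val conf = new MiniConf(Map("spark.sql.shuffle.partitions" -> "1"))
    assert(conf.get("spark.sql.shuffle.partitions", null) == "1")
    assert(conf.get("spark.sql.nonexistent", null) == null)
    assert(conf.get("spark.sql.nonexistent", "") == "")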
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 5bd36ec25ccb0..8e31c09482b08 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -96,6 +96,15 @@ class JDBCSuite extends SparkFunSuite
| partitionColumn 'THEID', lowerBound '1', upperBound '4', numPartitions '3')
""".stripMargin.replaceAll("\n", " "))
+ sql(
+ s"""
+ |CREATE OR REPLACE TEMPORARY VIEW partsoverflow
+ |USING org.apache.spark.sql.jdbc
+ |OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass',
+ | partitionColumn 'THEID', lowerBound '-9223372036854775808',
+ | upperBound '9223372036854775807', numPartitions '3')
+ """.stripMargin.replaceAll("\n", " "))
+
conn.prepareStatement("create table test.inttypes (a INT, b BOOLEAN, c TINYINT, "
+ "d SMALLINT, e BIGINT)").executeUpdate()
conn.prepareStatement("insert into test.inttypes values (1, false, 3, 4, 1234567890123)"
@@ -275,10 +284,13 @@ class JDBCSuite extends SparkFunSuite
// This is a test to reflect the discussion in SPARK-12218.
// Older versions of Spark had this kind of bug in the parquet data source.
- val df1 = sql("SELECT * FROM foobar WHERE NOT (THEID != 2 AND NAME != 'mary')")
- val df2 = sql("SELECT * FROM foobar WHERE NOT (THEID != 2) OR NOT (NAME != 'mary')")
+ val df1 = sql("SELECT * FROM foobar WHERE NOT (THEID != 2) OR NOT (NAME != 'mary')")
assert(df1.collect.toSet === Set(Row("mary", 2)))
- assert(df2.collect.toSet === Set(Row("mary", 2)))
+
+ // SPARK-22548: Incorrect nested AND expression pushed down to JDBC data source
+ val df2 = sql("SELECT * FROM foobar " +
+ "WHERE (THEID > 0 AND TRIM(NAME) = 'mary') OR (NAME = 'fred')")
+ assert(df2.collect.toSet === Set(Row("fred", 1), Row("mary", 2)))
def checkNotPushdown(df: DataFrame): DataFrame = {
val parentPlan = df.queryExecution.executedPlan
@@ -367,6 +379,12 @@ class JDBCSuite extends SparkFunSuite
assert(ids(2) === 3)
}
+ test("overflow of partition bound difference does not give negative stride") {
+ val df = sql("SELECT * FROM partsoverflow")
+ checkNumPartitions(df, expectedNumPartitions = 3)
+ assert(df.collect().length == 3)
+ }
+
test("Register JDBC query with renamed fields") {
// Regression test for bug SPARK-7345
sql(
@@ -397,6 +415,28 @@ class JDBCSuite extends SparkFunSuite
assert(e.contains("Invalid value `-1` for parameter `fetchsize`"))
}
+ test("Missing partition columns") {
+ withView("tempPeople") {
+ val e = intercept[IllegalArgumentException] {
+ sql(
+ s"""
+ |CREATE OR REPLACE TEMPORARY VIEW tempPeople
+ |USING org.apache.spark.sql.jdbc
+ |OPTIONS (
+ | url 'jdbc:h2:mem:testdb0;user=testUser;password=testPass',
+ | dbtable 'TEST.PEOPLE',
+ | lowerBound '0',
+ | upperBound '52',
+ | numPartitions '53',
+ | fetchSize '10000' )
+ """.stripMargin.replaceAll("\n", " "))
+ }.getMessage
+ assert(e.contains("When reading JDBC data sources, users need to specify all or none " +
+ "for the following options: 'partitionColumn', 'lowerBound', 'upperBound', and " +
+ "'numPartitions'"))
+ }
+ }
+
test("Basic API with FetchSize") {
(0 to 4).foreach { size =>
val properties = new Properties()
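
The new partsoverflow view spans the full Long range, so a naive upperBound - lowerBound wraps around and can yield a negative stride; that is what the "overflow of partition bound difference" test exercises. A hedged sketch of one overflow-safe way to compute the stride using BigInt; safeStride is an illustrative helper, not JDBCRelation's actual arithmetic.

    // Sketch: do the width arithmetic in BigInt space so Long.MaxValue - Long.MinValue cannot wrap.
    def safeStride(lowerBound: Long, upperBound: Long, numPartitions: Int): Long = {
      val width = BigInt(upperBound) - BigInt(lowerBound)
      (width / numPartitions).min(BigInt(Long.MaxValue)).toLong
    }

    safeStride(Long.MinValue, Long.MaxValue, 3)  // positive stride, no wrap-around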
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index bf1fd160704fa..d7ae45680fe56 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -323,8 +323,9 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
.option("partitionColumn", "foo")
.save()
}.getMessage
- assert(e.contains("If 'partitionColumn' is specified then 'lowerBound', 'upperBound'," +
- " and 'numPartitions' are required."))
+ assert(e.contains("When reading JDBC data sources, users need to specify all or none " +
+ "for the following options: 'partitionColumn', 'lowerBound', 'upperBound', and " +
+ "'numPartitions'"))
}
test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index 9b65419dba234..ba0ca666b5c14 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -90,6 +90,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
originalDataFrame: DataFrame): Unit = {
// This test verifies parts of the plan. Disable whole stage codegen.
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
+ val strategy = DataSourceStrategy(spark.sessionState.conf)
val bucketedDataFrame = spark.table("bucketed_table").select("i", "j", "k")
val BucketSpec(numBuckets, bucketColumnNames, _) = bucketSpec
// Limit: bucket pruning only works when the bucket column has one and only one column
@@ -98,7 +99,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
val bucketColumn = bucketedDataFrame.schema.toAttributes(bucketColumnIndex)
val matchedBuckets = new BitSet(numBuckets)
bucketValues.foreach { value =>
- matchedBuckets.set(DataSourceStrategy.getBucketId(bucketColumn, numBuckets, value))
+ matchedBuckets.set(strategy.getBucketId(bucketColumn, numBuckets, value))
}
// Filter could hide the bug in bucket pruning. Thus, skipping all the filters
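
The BucketedReadSuite change only moves getBucketId onto a DataSourceStrategy instance (it now needs the session conf); the pruning idea itself is unchanged: hash each literal value of the bucket column and mark the matching bucket in a bit set so the scan can skip every other bucket. A minimal sketch of that idea with a caller-supplied hash function; Spark hashes the unsafe row representation, so prunedBuckets below is illustrative only.

    import scala.collection.mutable

    // Sketch: mark the buckets whose hash matches a filter literal; scan only marked buckets.
    def prunedBuckets(values: Seq[Any], numBuckets: Int, hash: Any => Int): mutable.BitSet = {
      val matched = new mutable.BitSet(numBuckets)
      values.foreach { v =>
        // Non-negative modulo, since hash(v) may be negative.
        val bucketId = ((hash(v) % numBuckets) + numBuckets) % numBuckets
        matched += bucketId
      }
      matched
    }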
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala
index 85ba33e58a787..b5fb740b6eb77 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala
@@ -19,26 +19,39 @@ package org.apache.spark.sql.sources
import org.apache.spark.sql.{AnalysisException, SQLContext}
import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{StringType, StructField, StructType}
+import org.apache.spark.sql.types._
// please note that the META-INF/services had to be modified for the test directory for this to work
class DDLSourceLoadSuite extends DataSourceTest with SharedSQLContext {
- test("data sources with the same name") {
- intercept[RuntimeException] {
+ test("data sources with the same name - internal data sources") {
+ val e = intercept[AnalysisException] {
spark.read.format("Fluet da Bomb").load()
}
+ assert(e.getMessage.contains("Multiple sources found for Fluet da Bomb"))
+ }
+
+ test("data sources with the same name - internal data source/external data source") {
+ assert(spark.read.format("datasource").load().schema ==
+ StructType(Seq(StructField("longType", LongType, nullable = false))))
+ }
+
+ test("data sources with the same name - external data sources") {
+ val e = intercept[AnalysisException] {
+ spark.read.format("Fake external source").load()
+ }
+ assert(e.getMessage.contains("Multiple sources found for Fake external source"))
}
test("load data source from format alias") {
- spark.read.format("gathering quorum").load().schema ==
- StructType(Seq(StructField("stringType", StringType, nullable = false)))
+ assert(spark.read.format("gathering quorum").load().schema ==
+ StructType(Seq(StructField("stringType", StringType, nullable = false))))
}
test("specify full classname with duplicate formats") {
- spark.read.format("org.apache.spark.sql.sources.FakeSourceOne")
- .load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false)))
+ assert(spark.read.format("org.apache.spark.sql.sources.FakeSourceOne")
+ .load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false))))
}
test("should fail to load ORC without Hive Support") {
@@ -63,7 +76,7 @@ class FakeSourceOne extends RelationProvider with DataSourceRegister {
}
}
-class FakeSourceTwo extends RelationProvider with DataSourceRegister {
+class FakeSourceTwo extends RelationProvider with DataSourceRegister {
def shortName(): String = "Fluet da Bomb"
@@ -72,7 +85,7 @@ class FakeSourceTwo extends RelationProvider with DataSourceRegister {
override def sqlContext: SQLContext = cont
override def schema: StructType =
- StructType(Seq(StructField("stringType", StringType, nullable = false)))
+ StructType(Seq(StructField("integerType", IntegerType, nullable = false)))
}
}
@@ -88,3 +101,16 @@ class FakeSourceThree extends RelationProvider with DataSourceRegister {
StructType(Seq(StructField("stringType", StringType, nullable = false)))
}
}
+
+class FakeSourceFour extends RelationProvider with DataSourceRegister {
+
+ def shortName(): String = "datasource"
+
+ override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
+ new BaseRelation {
+ override def sqlContext: SQLContext = cont
+
+ override def schema: StructType =
+ StructType(Seq(StructField("longType", LongType, nullable = false)))
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala
index b16c9f8fc96b2..735e07c21373a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, Expression, Literal}
import org.apache.spark.sql.execution.datasources.DataSourceAnalysis
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.sql.types.{DataType, IntegerType, StructType}
class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll {
@@ -49,7 +49,11 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll {
}
Seq(true, false).foreach { caseSensitive =>
- val rule = DataSourceAnalysis(new SQLConf().copy(SQLConf.CASE_SENSITIVE -> caseSensitive))
+ val conf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> caseSensitive)
+ def cast(e: Expression, dt: DataType): Expression = {
+ Cast(e, dt, Option(conf.sessionLocalTimeZone))
+ }
+ val rule = DataSourceAnalysis(conf)
test(
s"convertStaticPartitions only handle INSERT having at least static partitions " +
s"(caseSensitive: $caseSensitive)") {
@@ -150,7 +154,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll {
if (!caseSensitive) {
val nonPartitionedAttributes = Seq('e.int, 'f.int)
val expected = nonPartitionedAttributes ++
- Seq(Cast(Literal("1"), IntegerType), Cast(Literal("3"), IntegerType))
+ Seq(cast(Literal("1"), IntegerType), cast(Literal("3"), IntegerType))
val actual = rule.convertStaticPartitions(
sourceAttributes = nonPartitionedAttributes,
providedPartitions = Map("b" -> Some("1"), "C" -> Some("3")),
@@ -162,7 +166,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll {
{
val nonPartitionedAttributes = Seq('e.int, 'f.int)
val expected = nonPartitionedAttributes ++
- Seq(Cast(Literal("1"), IntegerType), Cast(Literal("3"), IntegerType))
+ Seq(cast(Literal("1"), IntegerType), cast(Literal("3"), IntegerType))
val actual = rule.convertStaticPartitions(
sourceAttributes = nonPartitionedAttributes,
providedPartitions = Map("b" -> Some("1"), "c" -> Some("3")),
@@ -174,7 +178,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll {
// Test the case having a single static partition column.
{
val nonPartitionedAttributes = Seq('e.int, 'f.int)
- val expected = nonPartitionedAttributes ++ Seq(Cast(Literal("1"), IntegerType))
+ val expected = nonPartitionedAttributes ++ Seq(cast(Literal("1"), IntegerType))
val actual = rule.convertStaticPartitions(
sourceAttributes = nonPartitionedAttributes,
providedPartitions = Map("b" -> Some("1")),
@@ -189,7 +193,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll {
val dynamicPartitionAttributes = Seq('g.int)
val expected =
nonPartitionedAttributes ++
- Seq(Cast(Literal("1"), IntegerType)) ++
+ Seq(cast(Literal("1"), IntegerType)) ++
dynamicPartitionAttributes
val actual = rule.convertStaticPartitions(
sourceAttributes = nonPartitionedAttributes ++ dynamicPartitionAttributes,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
index 5a0388ec1d1db..c902b0afcce6c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
@@ -326,7 +326,7 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext with Predic
assert(ColumnsRequired.set === requiredColumnNames)
val table = spark.table("oneToTenFiltered")
- val relation = table.queryExecution.logical.collectFirst {
+ val relation = table.queryExecution.analyzed.collectFirst {
case LogicalRelation(r, _, _) => r
}.get
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
index 2eae66dda88de..41abff2a5da25 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
@@ -345,4 +345,25 @@ class InsertSuite extends DataSourceTest with SharedSQLContext {
)
}
}
+
+ test("SPARK-21203 wrong results of insertion of Array of Struct") {
+ val tabName = "tab1"
+ withTable(tabName) {
+ spark.sql(
+ """
+ |CREATE TABLE `tab1`
+ |(`custom_fields` ARRAY<STRUCT<`id`: BIGINT, `value`: STRING>>)
+ |USING parquet
+ """.stripMargin)
+ spark.sql(
+ """
+ |INSERT INTO `tab1`
+ |SELECT ARRAY(named_struct('id', 1, 'value', 'a'), named_struct('id', 2, 'value', 'b'))
+ """.stripMargin)
+
+ checkAnswer(
+ spark.sql("SELECT custom_fields.id, custom_fields.value FROM tab1"),
+ Row(Array(1, 2), Array("a", "b")))
+ }
+ }
}
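
The SPARK-21203 test expects that selecting a struct field through an array column yields an array of that field's values, so custom_fields.id returns [1, 2] and custom_fields.value returns [a, b] for the single inserted row. A small sketch of that shape, with plain Scala tuples standing in for the structs.

    // Sketch: field access over an array of structs collects that field from every element.
    val customFields = Seq((1L, "a"), (2L, "b"))   // the inserted ARRAY of (id, value) structs
    val ids = customFields.map(_._1)               // custom_fields.id    -> [1, 2]
    val values = customFields.map(_._2)            // custom_fields.value -> [a, b]
    assert(ids == Seq(1L, 2L) && values == Seq("a", "b"))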
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala
index 6dd4847ead738..c25c3f62158cf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala
@@ -92,12 +92,12 @@ class PathOptionSuite extends DataSourceTest with SharedSQLContext {
s"""
|CREATE TABLE src
|USING ${classOf[TestOptionsSource].getCanonicalName}
- |OPTIONS (PATH '$p')
+ |OPTIONS (PATH '${p.toURI}')
|AS SELECT 1
""".stripMargin)
assert(
spark.table("src").schema.head.metadata.getString("path") ==
- p.getAbsolutePath)
+ p.toURI.toString)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/fakeExternalSources.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/fakeExternalSources.scala
new file mode 100644
index 0000000000000..0dfd75e709123
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/fakeExternalSources.scala
@@ -0,0 +1,64 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.fakesource
+
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider}
+import org.apache.spark.sql.types._
+
+
+// Note that the package name is intentionally mismatched in order to resemble external data sources
+// and test the detection for them.
+class FakeExternalSourceOne extends RelationProvider with DataSourceRegister {
+
+ def shortName(): String = "Fake external source"
+
+ override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
+ new BaseRelation {
+ override def sqlContext: SQLContext = cont
+
+ override def schema: StructType =
+ StructType(Seq(StructField("stringType", StringType, nullable = false)))
+ }
+}
+
+class FakeExternalSourceTwo extends RelationProvider with DataSourceRegister {
+
+ def shortName(): String = "Fake external source"
+
+ override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
+ new BaseRelation {
+ override def sqlContext: SQLContext = cont
+
+ override def schema: StructType =
+ StructType(Seq(StructField("integerType", IntegerType, nullable = false)))
+ }
+}
+
+class FakeExternalSourceThree extends RelationProvider with DataSourceRegister {
+
+ def shortName(): String = "datasource"
+
+ override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
+ new BaseRelation {
+ override def sqlContext: SQLContext = cont
+
+ override def schema: StructType =
+ StructType(Seq(StructField("byteType", ByteType, nullable = false)))
+ }
+}
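
Together with the DDLSourceLoadSuite changes above, this new file covers three short-name collision cases: two internal sources with the same alias fail, two external sources with the same alias fail, and an internal/external clash resolves to the internal (org.apache.spark) provider. A hedged sketch of that resolution rule; resolveShortName is illustrative and not DataSource.lookupDataSource itself.

    // Sketch: a single match wins; among duplicates, exactly one internal provider wins.
    def resolveShortName(alias: String, candidates: Seq[Class[_]]): Class[_] = candidates match {
      case Seq(single) => single
      case multiple =>
        val internal = multiple.filter(_.getName.startsWith("org.apache.spark"))
        if (internal.size == 1) internal.head
        else sys.error(s"Multiple sources found for $alias")
    }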
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala
index a15c2cff930fc..e858b7d9998a8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala
@@ -268,4 +268,17 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll {
CheckLastBatch(7)
)
}
+
+ test("SPARK-21546: dropDuplicates should ignore watermark when it's not a key") {
+ val input = MemoryStream[(Int, Int)]
+ val df = input.toDS.toDF("id", "time")
+ .withColumn("time", $"time".cast("timestamp"))
+ .withWatermark("time", "1 second")
+ .dropDuplicates("id")
+ .select($"id", $"time".cast("long"))
+ testStream(df)(
+ AddData(input, 1 -> 1, 1 -> 2, 2 -> 2),
+ CheckLastBatch(1 -> 1, 2 -> 2)
+ )
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala
index fd850a7365e20..4f19fa0bb4a97 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala
@@ -21,7 +21,7 @@ import java.{util => ju}
import java.text.SimpleDateFormat
import java.util.Date
-import org.scalatest.BeforeAndAfter
+import org.scalatest.{BeforeAndAfter, Matchers}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
@@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.functions.{count, window}
import org.apache.spark.sql.streaming.OutputMode._
-class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Logging {
+class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matchers with Logging {
import testImplicits._
@@ -38,6 +38,43 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Loggin
sqlContext.streams.active.foreach(_.stop())
}
+ test("EventTimeStats") {
+ val epsilon = 10E-6
+
+ val stats = EventTimeStats(max = 100, min = 10, avg = 20.0, count = 5)
+ stats.add(80L)
+ stats.max should be (100)
+ stats.min should be (10)
+ stats.avg should be (30.0 +- epsilon)
+ stats.count should be (6)
+
+ val stats2 = EventTimeStats(80L, 5L, 15.0, 4)
+ stats.merge(stats2)
+ stats.max should be (100)
+ stats.min should be (5)
+ stats.avg should be (24.0 +- epsilon)
+ stats.count should be (10)
+ }
+
+ test("EventTimeStats: avg on large values") {
+ val epsilon = 10E-6
+ val largeValue = 10000000000L // 10B
+ // Make sure `largeValue` will cause an overflow if we use a Long sum to compute the average.
+ assert(largeValue * largeValue != BigInt(largeValue) * BigInt(largeValue))
+ val stats =
+ EventTimeStats(max = largeValue, min = largeValue, avg = largeValue, count = largeValue - 1)
+ stats.add(largeValue)
+ stats.avg should be (largeValue.toDouble +- epsilon)
+
+ val stats2 = EventTimeStats(
+ max = largeValue + 1,
+ min = largeValue,
+ avg = largeValue + 1,
+ count = largeValue)
+ stats.merge(stats2)
+ stats.avg should be ((largeValue + 0.5) +- epsilon)
+ }
+
test("error on bad column") {
val inputData = MemoryStream[Int].toDF()
val e = intercept[AnalysisException] {
@@ -344,6 +381,44 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Loggin
assert(eventTimeColumns(0).name === "second")
}
+ test("EventTime watermark should be ignored in batch query.") {
+ val df = testData
+ .withColumn("eventTime", $"key".cast("timestamp"))
+ .withWatermark("eventTime", "1 minute")
+ .select("eventTime")
+ .as[Long]
+
+ checkDataset[Long](df, 1L to 100L: _*)
+ }
+
+ test("SPARK-21565: watermark operator accepts attributes from replacement") {
+ withTempDir { dir =>
+ dir.delete()
+
+ val df = Seq(("a", 100.0, new java.sql.Timestamp(100L)))
+ .toDF("symbol", "price", "eventTime")
+ df.write.json(dir.getCanonicalPath)
+
+ val input = spark.readStream.schema(df.schema)
+ .json(dir.getCanonicalPath)
+
+ val groupEvents = input
+ .withWatermark("eventTime", "2 seconds")
+ .groupBy("symbol", "eventTime")
+ .agg(count("price") as 'count)
+ .select("symbol", "eventTime", "count")
+ val q = groupEvents.writeStream
+ .outputMode("append")
+ .format("console")
+ .start()
+ try {
+ q.processAllAvailable()
+ } finally {
+ q.stop()
+ }
+ }
+ }
+
private def assertNumStateRows(numTotalRows: Long): AssertOnQuery = AssertOnQuery { q =>
val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get
assert(progressWithData.stateOperators(0).numRowsTotal === numTotalRows)
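
The "avg on large values" test documents why EventTimeStats cannot keep a running Long sum of event times: summing on the order of 10^10 values of 10^10 ms each exceeds Long.MaxValue, so the average has to be maintained incrementally. A minimal sketch of an incremental update and merge that reproduce the assertions above; Stats is an illustrative stand-in, not the EventTimeStats implementation.

    // Sketch: update the mean incrementally instead of keeping a Long sum.
    case class Stats(var max: Long, var min: Long, var avg: Double, var count: Long) {
      def add(eventTime: Long): Unit = {
        max = math.max(max, eventTime)
        min = math.min(min, eventTime)
        count += 1
        avg += (eventTime - avg) / count                  // no running sum, so no overflow
      }
      def merge(that: Stats): Unit = {
        max = math.max(max, that.max)
        min = math.min(min, that.min)
        count += that.count
        avg += (that.avg - avg) * that.count / count      // count-weighted combination
      }
    }

For example, Stats(100, 10, 20.0, 5).add(80L) gives avg 30.0 and count 6, matching the first block of assertions.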
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
index 1211242b9fbb4..bb6a27803bb20 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
@@ -19,10 +19,12 @@ package org.apache.spark.sql.streaming
import java.util.Locale
+import org.apache.hadoop.fs.Path
+
import org.apache.spark.sql.{AnalysisException, DataFrame}
import org.apache.spark.sql.execution.DataSourceScanExec
import org.apache.spark.sql.execution.datasources._
-import org.apache.spark.sql.execution.streaming.{MemoryStream, MetadataLogFileIndex}
+import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.apache.spark.util.Utils
@@ -62,6 +64,35 @@ class FileStreamSinkSuite extends StreamTest {
}
}
+ test("SPARK-21167: encode and decode path correctly") {
+ val inputData = MemoryStream[String]
+ val ds = inputData.toDS()
+
+ val outputDir = Utils.createTempDir(namePrefix = "stream.output").getCanonicalPath
+ val checkpointDir = Utils.createTempDir(namePrefix = "stream.checkpoint").getCanonicalPath
+
+ val query = ds.map(s => (s, s.length))
+ .toDF("value", "len")
+ .writeStream
+ .partitionBy("value")
+ .option("checkpointLocation", checkpointDir)
+ .format("parquet")
+ .start(outputDir)
+
+ try {
+ // The output is partitioned by "value", so the value will appear in the file path.
+ // This is to test if we handle spaces in the path correctly.
+ inputData.addData("hello world")
+ failAfter(streamingTimeout) {
+ query.processAllAvailable()
+ }
+ val outputDf = spark.read.parquet(outputDir)
+ checkDatasetUnorderly(outputDf.as[(Int, String)], ("hello world".length, "hello world"))
+ } finally {
+ query.stop()
+ }
+ }
+
test("partitioned writing and batch reading") {
val inputData = MemoryStream[Int]
val ds = inputData.toDS()
@@ -145,6 +176,43 @@ class FileStreamSinkSuite extends StreamTest {
}
}
+ test("partitioned writing and batch reading with 'basePath'") {
+ withTempDir { outputDir =>
+ withTempDir { checkpointDir =>
+ val outputPath = outputDir.getAbsolutePath
+ val inputData = MemoryStream[Int]
+ val ds = inputData.toDS()
+
+ var query: StreamingQuery = null
+
+ try {
+ query =
+ ds.map(i => (i, -i, i * 1000))
+ .toDF("id1", "id2", "value")
+ .writeStream
+ .partitionBy("id1", "id2")
+ .option("checkpointLocation", checkpointDir.getAbsolutePath)
+ .format("parquet")
+ .start(outputPath)
+
+ inputData.addData(1, 2, 3)
+ failAfter(streamingTimeout) {
+ query.processAllAvailable()
+ }
+
+ val readIn = spark.read.option("basePath", outputPath).parquet(s"$outputDir/*/*")
+ checkDatasetUnorderly(
+ readIn.as[(Int, Int, Int)],
+ (1000, 1, -1), (2000, 2, -2), (3000, 3, -3))
+ } finally {
+ if (query != null) {
+ query.stop()
+ }
+ }
+ }
+ }
+ }
+
// This tests whether FileStreamSink works with aggregations. Specifically, it tests
// whether the correct streaming QueryExecution (i.e. IncrementalExecution) is used to
// execute the trigger for writing data to the file sink. See SPARK-18440 for more details.
@@ -266,4 +334,22 @@ class FileStreamSinkSuite extends StreamTest {
}
}
}
+
+ test("FileStreamSink.ancestorIsMetadataDirectory()") {
+ val hadoopConf = spark.sparkContext.hadoopConfiguration
+ def assertAncestorIsMetadataDirectory(path: String): Unit =
+ assert(FileStreamSink.ancestorIsMetadataDirectory(new Path(path), hadoopConf))
+ def assertAncestorIsNotMetadataDirectory(path: String): Unit =
+ assert(!FileStreamSink.ancestorIsMetadataDirectory(new Path(path), hadoopConf))
+
+ assertAncestorIsMetadataDirectory(s"/${FileStreamSink.metadataDir}")
+ assertAncestorIsMetadataDirectory(s"/${FileStreamSink.metadataDir}/")
+ assertAncestorIsMetadataDirectory(s"/a/${FileStreamSink.metadataDir}")
+ assertAncestorIsMetadataDirectory(s"/a/${FileStreamSink.metadataDir}/")
+ assertAncestorIsMetadataDirectory(s"/a/b/${FileStreamSink.metadataDir}/c")
+ assertAncestorIsMetadataDirectory(s"/a/b/${FileStreamSink.metadataDir}/c/")
+
+ assertAncestorIsNotMetadataDirectory(s"/a/b/c")
+ assertAncestorIsNotMetadataDirectory(s"/a/b/c/${FileStreamSink.metadataDir}extra")
+ }
}
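
The new ancestorIsMetadataDirectory() test checks that a path counts as sink metadata when any ancestor component equals FileStreamSink.metadataDir, and not when a component merely shares that prefix. A minimal sketch of the check over plain strings, assuming the metadata directory name _spark_metadata; the real helper works on Hadoop Path objects, which is why the test passes a hadoopConf.

    // Sketch: an exact component match anywhere in the path marks it as sink metadata.
    def ancestorIsMetadataDirectory(path: String, metadataDir: String = "_spark_metadata"): Boolean =
      path.split("/").filter(_.nonEmpty).contains(metadataDir)

    assert(ancestorIsMetadataDirectory("/a/b/_spark_metadata/c"))
    assert(!ancestorIsMetadataDirectory("/a/b/c/_spark_metadataextra"))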
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
index 2108b118bf059..e2ec690d90e52 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
@@ -1314,6 +1314,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest {
val metadataLog =
new FileStreamSourceLog(FileStreamSourceLog.VERSION, spark, dir.getAbsolutePath)
assert(metadataLog.add(0, Array(FileEntry(s"$scheme:///file1", 100L, 0))))
+ assert(metadataLog.add(1, Array(FileEntry(s"$scheme:///file2", 200L, 0))))
val newSource = new FileStreamSource(spark, s"$scheme:///", "parquet", StructType(Nil), Nil,
dir.getAbsolutePath, Map.empty)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
index 85aa7dbe9ed86..d7370642d08b7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
@@ -73,14 +73,15 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
assert(state.hasRemoved === shouldBeRemoved)
}
+ // === Tests for state in streaming queries ===
// Updating empty state
- state = new GroupStateImpl[String](None)
+ state = GroupStateImpl.createForStreaming(None, 1, 1, NoTimeout, hasTimedOut = false)
testState(None)
state.update("")
testState(Some(""), shouldBeUpdated = true)
// Updating exiting state
- state = new GroupStateImpl[String](Some("2"))
+ state = GroupStateImpl.createForStreaming(Some("2"), 1, 1, NoTimeout, hasTimedOut = false)
testState(Some("2"))
state.update("3")
testState(Some("3"), shouldBeUpdated = true)
@@ -99,24 +100,34 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
}
test("GroupState - setTimeout**** with NoTimeout") {
- for (initState <- Seq(None, Some(5))) {
- // for different initial state
- implicit val state = new GroupStateImpl(initState, 1000, 1000, NoTimeout, hasTimedOut = false)
- testTimeoutDurationNotAllowed[UnsupportedOperationException](state)
- testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)
+ for (initValue <- Seq(None, Some(5))) {
+ val states = Seq(
+ GroupStateImpl.createForStreaming(initValue, 1000, 1000, NoTimeout, hasTimedOut = false),
+ GroupStateImpl.createForBatch(NoTimeout)
+ )
+ for (state <- states) {
+ // for streaming queries
+ testTimeoutDurationNotAllowed[UnsupportedOperationException](state)
+ testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)
+
+ // for batch queries
+ testTimeoutDurationNotAllowed[UnsupportedOperationException](state)
+ testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)
+ }
}
}
test("GroupState - setTimeout**** with ProcessingTimeTimeout") {
- implicit var state: GroupStateImpl[Int] = null
-
- state = new GroupStateImpl[Int](None, 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false)
+ // for streaming queries
+ var state: GroupStateImpl[Int] = GroupStateImpl.createForStreaming(
+ None, 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false)
assert(state.getTimeoutTimestamp === NO_TIMESTAMP)
- testTimeoutDurationNotAllowed[IllegalStateException](state)
+ state.setTimeoutDuration(500)
+ assert(state.getTimeoutTimestamp === 1500) // can be set without initializing state
testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)
state.update(5)
- assert(state.getTimeoutTimestamp === NO_TIMESTAMP)
+ assert(state.getTimeoutTimestamp === 1500) // does not change
state.setTimeoutDuration(1000)
assert(state.getTimeoutTimestamp === 2000)
state.setTimeoutDuration("2 second")
@@ -124,19 +135,38 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)
state.remove()
+ assert(state.getTimeoutTimestamp === 3000) // does not change
+ state.setTimeoutDuration(500) // can still be set
+ assert(state.getTimeoutTimestamp === 1500)
+ testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)
+
+ // for batch queries
+ state = GroupStateImpl.createForBatch(ProcessingTimeTimeout).asInstanceOf[GroupStateImpl[Int]]
assert(state.getTimeoutTimestamp === NO_TIMESTAMP)
- testTimeoutDurationNotAllowed[IllegalStateException](state)
+ state.setTimeoutDuration(500)
+ testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)
+
+ state.update(5)
+ state.setTimeoutDuration(1000)
+ state.setTimeoutDuration("2 second")
+ testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)
+
+ state.remove()
+ state.setTimeoutDuration(500)
testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)
}
test("GroupState - setTimeout**** with EventTimeTimeout") {
- implicit val state = new GroupStateImpl[Int](
- None, 1000, 1000, EventTimeTimeout, hasTimedOut = false)
+ var state: GroupStateImpl[Int] = GroupStateImpl.createForStreaming(
+ None, 1000, 1000, EventTimeTimeout, false)
+
assert(state.getTimeoutTimestamp === NO_TIMESTAMP)
testTimeoutDurationNotAllowed[UnsupportedOperationException](state)
- testTimeoutTimestampNotAllowed[IllegalStateException](state)
+ state.setTimeoutTimestamp(5000)
+ assert(state.getTimeoutTimestamp === 5000) // can be set without initializing state
state.update(5)
+ assert(state.getTimeoutTimestamp === 5000) // does not change
state.setTimeoutTimestamp(10000)
assert(state.getTimeoutTimestamp === 10000)
state.setTimeoutTimestamp(new Date(20000))
@@ -144,9 +174,25 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
testTimeoutDurationNotAllowed[UnsupportedOperationException](state)
state.remove()
+ assert(state.getTimeoutTimestamp === 20000)
+ state.setTimeoutTimestamp(5000)
+ assert(state.getTimeoutTimestamp === 5000) // can be set after removing state
+ testTimeoutDurationNotAllowed[UnsupportedOperationException](state)
+
+ // for batch queries
+ state = GroupStateImpl.createForBatch(EventTimeTimeout).asInstanceOf[GroupStateImpl[Int]]
assert(state.getTimeoutTimestamp === NO_TIMESTAMP)
testTimeoutDurationNotAllowed[UnsupportedOperationException](state)
- testTimeoutTimestampNotAllowed[IllegalStateException](state)
+ state.setTimeoutTimestamp(5000)
+
+ state.update(5)
+ state.setTimeoutTimestamp(10000)
+ state.setTimeoutTimestamp(new Date(20000))
+ testTimeoutDurationNotAllowed[UnsupportedOperationException](state)
+
+ state.remove()
+ state.setTimeoutTimestamp(5000)
+ testTimeoutDurationNotAllowed[UnsupportedOperationException](state)
}
test("GroupState - illegal params to setTimeout****") {
@@ -154,47 +200,86 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
// Test setTimeout****() with illegal values
def testIllegalTimeout(body: => Unit): Unit = {
- intercept[IllegalArgumentException] { body }
+ intercept[IllegalArgumentException] {
+ body
+ }
assert(state.getTimeoutTimestamp === NO_TIMESTAMP)
}
- state = new GroupStateImpl(Some(5), 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false)
- testIllegalTimeout { state.setTimeoutDuration(-1000) }
- testIllegalTimeout { state.setTimeoutDuration(0) }
- testIllegalTimeout { state.setTimeoutDuration("-2 second") }
- testIllegalTimeout { state.setTimeoutDuration("-1 month") }
- testIllegalTimeout { state.setTimeoutDuration("1 month -1 day") }
+ state = GroupStateImpl.createForStreaming(
+ Some(5), 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false)
+ testIllegalTimeout {
+ state.setTimeoutDuration(-1000)
+ }
+ testIllegalTimeout {
+ state.setTimeoutDuration(0)
+ }
+ testIllegalTimeout {
+ state.setTimeoutDuration("-2 second")
+ }
+ testIllegalTimeout {
+ state.setTimeoutDuration("-1 month")
+ }
+ testIllegalTimeout {
+ state.setTimeoutDuration("1 month -1 day")
+ }
- state = new GroupStateImpl(Some(5), 1000, 1000, EventTimeTimeout, hasTimedOut = false)
- testIllegalTimeout { state.setTimeoutTimestamp(-10000) }
- testIllegalTimeout { state.setTimeoutTimestamp(10000, "-3 second") }
- testIllegalTimeout { state.setTimeoutTimestamp(10000, "-1 month") }
- testIllegalTimeout { state.setTimeoutTimestamp(10000, "1 month -1 day") }
- testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000)) }
- testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000), "-3 second") }
- testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000), "-1 month") }
- testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000), "1 month -1 day") }
+ state = GroupStateImpl.createForStreaming(
+ Some(5), 1000, 1000, EventTimeTimeout, hasTimedOut = false)
+ testIllegalTimeout {
+ state.setTimeoutTimestamp(-10000)
+ }
+ testIllegalTimeout {
+ state.setTimeoutTimestamp(10000, "-3 second")
+ }
+ testIllegalTimeout {
+ state.setTimeoutTimestamp(10000, "-1 month")
+ }
+ testIllegalTimeout {
+ state.setTimeoutTimestamp(10000, "1 month -1 day")
+ }
+ testIllegalTimeout {
+ state.setTimeoutTimestamp(new Date(-10000))
+ }
+ testIllegalTimeout {
+ state.setTimeoutTimestamp(new Date(-10000), "-3 second")
+ }
+ testIllegalTimeout {
+ state.setTimeoutTimestamp(new Date(-10000), "-1 month")
+ }
+ testIllegalTimeout {
+ state.setTimeoutTimestamp(new Date(-10000), "1 month -1 day")
+ }
}
test("GroupState - hasTimedOut") {
for (timeoutConf <- Seq(NoTimeout, ProcessingTimeTimeout, EventTimeTimeout)) {
+ // for streaming queries
for (initState <- Seq(None, Some(5))) {
- val state1 = new GroupStateImpl(initState, 1000, 1000, timeoutConf, hasTimedOut = false)
+ val state1 = GroupStateImpl.createForStreaming(
+ initState, 1000, 1000, timeoutConf, hasTimedOut = false)
assert(state1.hasTimedOut === false)
- val state2 = new GroupStateImpl(initState, 1000, 1000, timeoutConf, hasTimedOut = true)
+
+ val state2 = GroupStateImpl.createForStreaming(
+ initState, 1000, 1000, timeoutConf, hasTimedOut = true)
assert(state2.hasTimedOut === true)
}
+
+ // for batch queries
+ assert(GroupStateImpl.createForBatch(timeoutConf).hasTimedOut === false)
}
}
test("GroupState - primitive type") {
- var intState = new GroupStateImpl[Int](None)
+ var intState = GroupStateImpl.createForStreaming[Int](
+ None, 1000, 1000, NoTimeout, hasTimedOut = false)
intercept[NoSuchElementException] {
intState.get
}
assert(intState.getOption === None)
- intState = new GroupStateImpl[Int](Some(10))
+ intState = GroupStateImpl.createForStreaming[Int](
+ Some(10), 1000, 1000, NoTimeout, hasTimedOut = false)
assert(intState.get == 10)
intState.update(0)
assert(intState.get == 0)
@@ -210,7 +295,6 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
val beforeTimeoutThreshold = 999
val afterTimeoutThreshold = 1001
-
// Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout = NoTimeout
for (priorState <- Seq(None, Some(0))) {
val priorStateStr = if (priorState.nonEmpty) "prior state set" else "no prior state"
@@ -318,6 +402,44 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
}
}
+ // Currently disallowed cases for StateStoreUpdater.updateStateForKeysWithData().
+ // Try to remove these cases in the future.
+ for (priorTimeoutTimestamp <- Seq(NO_TIMESTAMP, 1000)) {
+ val testName =
+ if (priorTimeoutTimestamp != NO_TIMESTAMP) "prior timeout set" else "no prior timeout"
+ testStateUpdateWithData(
+ s"ProcessingTimeTimeout - $testName - setting timeout without init state not allowed",
+ stateUpdates = state => { state.setTimeoutDuration(5000) },
+ timeoutConf = ProcessingTimeTimeout,
+ priorState = None,
+ priorTimeoutTimestamp = priorTimeoutTimestamp,
+ expectedException = classOf[IllegalStateException])
+
+ testStateUpdateWithData(
+ s"ProcessingTimeTimeout - $testName - setting timeout with state removal not allowed",
+ stateUpdates = state => { state.remove(); state.setTimeoutDuration(5000) },
+ timeoutConf = ProcessingTimeTimeout,
+ priorState = Some(5),
+ priorTimeoutTimestamp = priorTimeoutTimestamp,
+ expectedException = classOf[IllegalStateException])
+
+ testStateUpdateWithData(
+ s"EventTimeTimeout - $testName - setting timeout without init state not allowed",
+ stateUpdates = state => { state.setTimeoutTimestamp(10000) },
+ timeoutConf = EventTimeTimeout,
+ priorState = None,
+ priorTimeoutTimestamp = priorTimeoutTimestamp,
+ expectedException = classOf[IllegalStateException])
+
+ testStateUpdateWithData(
+ s"EventTimeTimeout - $testName - setting timeout with state removal not allowed",
+ stateUpdates = state => { state.remove(); state.setTimeoutTimestamp(10000) },
+ timeoutConf = EventTimeTimeout,
+ priorState = Some(5),
+ priorTimeoutTimestamp = priorTimeoutTimestamp,
+ expectedException = classOf[IllegalStateException])
+ }
+
// Tests for StateStoreUpdater.updateStateForTimedOutKeys()
val preTimeoutState = Some(5)
for (timeoutConf <- Seq(ProcessingTimeTimeout, EventTimeTimeout)) {
@@ -558,7 +680,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
.flatMapGroupsWithState(Update, ProcessingTimeTimeout)(stateFunc)
testStream(result, Update)(
- StartStream(ProcessingTime("1 second"), triggerClock = clock),
+ StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock),
AddData(inputData, "a"),
AdvanceManualClock(1 * 1000),
CheckLastBatch(("a", "1")),
@@ -589,7 +711,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
)
}
- test("flatMapGroupsWithState - streaming with event time timeout") {
+ test("flatMapGroupsWithState - streaming with event time timeout + watermark") {
// Function to maintain the max event time
// Returns the max event time in the state, or -1 if the state was removed by timeout
val stateFunc = (
@@ -623,7 +745,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
.flatMapGroupsWithState(Update, EventTimeTimeout)(stateFunc)
testStream(result, Update)(
- StartStream(ProcessingTime("1 second")),
+ StartStream(Trigger.ProcessingTime("1 second")),
AddData(inputData, ("a", 11), ("a", 13), ("a", 15)), // Set timeout timestamp of ...
CheckLastBatch(("a", 15)), // "a" to 15 + 5 = 20s, watermark to 5s
AddData(inputData, ("a", 4)), // Add data older than watermark for "a"
@@ -677,15 +799,21 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
}
test("mapGroupsWithState - batch") {
- val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => {
+ // Test the following
+ // - no initial state
+ // - timeout operations work and do not throw any error [SPARK-20792]
+ // - works with primitive state type
+ val stateFunc = (key: String, values: Iterator[String], state: GroupState[Int]) => {
if (state.exists) throw new IllegalArgumentException("state.exists should be false")
+ state.setTimeoutTimestamp(0, "1 hour")
+ state.update(10)
(key, values.size)
}
checkAnswer(
spark.createDataset(Seq("a", "a", "b"))
.groupByKey(x => x)
- .mapGroupsWithState(stateFunc)
+ .mapGroupsWithState(EventTimeTimeout)(stateFunc)
.toDF,
spark.createDataset(Seq(("a", 2), ("b", 1))).toDF)
}
@@ -761,6 +889,44 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
assert(e.getMessage === "The output mode of function should be append or update")
}
+ def testWithTimeout(timeoutConf: GroupStateTimeout): Unit = {
+ test("SPARK-20714: watermark does not fail query when timeout = " + timeoutConf) {
+ // Function to maintain running count up to 2, and then remove the count
+ // Returns the data and the count (-1 if count reached beyond 2 and state was just removed)
+ val stateFunc =
+ (key: String, values: Iterator[(String, Long)], state: GroupState[RunningCount]) => {
+ if (state.hasTimedOut) {
+ state.remove()
+ Iterator((key, "-1"))
+ } else {
+ val count = state.getOption.map(_.count).getOrElse(0L) + values.size
+ state.update(RunningCount(count))
+ state.setTimeoutDuration("10 seconds")
+ Iterator((key, count.toString))
+ }
+ }
+
+ val clock = new StreamManualClock
+ val inputData = MemoryStream[(String, Long)]
+ val result =
+ inputData.toDF().toDF("key", "time")
+ .selectExpr("key", "cast(time as timestamp) as timestamp")
+ .withWatermark("timestamp", "10 second")
+ .as[(String, Long)]
+ .groupByKey(x => x._1)
+ .flatMapGroupsWithState(Update, ProcessingTimeTimeout)(stateFunc)
+
+ testStream(result, Update)(
+ StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock),
+ AddData(inputData, ("a", 1L)),
+ AdvanceManualClock(1 * 1000),
+ CheckLastBatch(("a", "1"))
+ )
+ }
+ }
+ testWithTimeout(NoTimeout)
+ testWithTimeout(ProcessingTimeTimeout)
+
def testStateUpdateWithData(
testName: String,
stateUpdates: GroupState[Int] => Unit,
@@ -768,7 +934,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
priorState: Option[Int],
priorTimeoutTimestamp: Long = NO_TIMESTAMP,
expectedState: Option[Int] = None,
- expectedTimeoutTimestamp: Long = NO_TIMESTAMP): Unit = {
+ expectedTimeoutTimestamp: Long = NO_TIMESTAMP,
+ expectedException: Class[_ <: Exception] = null): Unit = {
if (priorState.isEmpty && priorTimeoutTimestamp != NO_TIMESTAMP) {
return // there can be no prior timestamp, when there is no prior state
@@ -782,7 +949,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
}
testStateUpdate(
testTimeoutUpdates = false, mapGroupsFunc, timeoutConf,
- priorState, priorTimeoutTimestamp, expectedState, expectedTimeoutTimestamp)
+ priorState, priorTimeoutTimestamp,
+ expectedState, expectedTimeoutTimestamp, expectedException)
}
}
@@ -801,9 +969,10 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
stateUpdates(state)
Iterator.empty
}
+
testStateUpdate(
testTimeoutUpdates = true, mapGroupsFunc, timeoutConf = timeoutConf,
- preTimeoutState, priorTimeoutTimestamp, expectedState, expectedTimeoutTimestamp)
+ preTimeoutState, priorTimeoutTimestamp, expectedState, expectedTimeoutTimestamp, null)
}
}
@@ -814,7 +983,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
priorState: Option[Int],
priorTimeoutTimestamp: Long,
expectedState: Option[Int],
- expectedTimeoutTimestamp: Long): Unit = {
+ expectedTimeoutTimestamp: Long,
+ expectedException: Class[_ <: Exception]): Unit = {
val store = newStateStore()
val mapGroupsSparkPlan = newFlatMapGroupsWithStateExec(
@@ -829,22 +999,30 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
}
// Call updating function to update state store
- val returnedIter = if (testTimeoutUpdates) {
- updater.updateStateForTimedOutKeys()
- } else {
- updater.updateStateForKeysWithData(Iterator(key))
+ def callFunction() = {
+ val returnedIter = if (testTimeoutUpdates) {
+ updater.updateStateForTimedOutKeys()
+ } else {
+ updater.updateStateForKeysWithData(Iterator(key))
+ }
+ returnedIter.size // consume the iterator to force state updates
}
- returnedIter.size // consumer the iterator to force state updates
-
- // Verify updated state in store
- val updatedStateRow = store.get(key)
- assert(
- updater.getStateObj(updatedStateRow).map(_.toString.toInt) === expectedState,
- "final state not as expected")
- if (updatedStateRow.nonEmpty) {
+ if (expectedException != null) {
+ // Call function and verify the exception type
+ val e = intercept[Exception] { callFunction() }
+ assert(e.getClass === expectedException, "Exception thrown but of the wrong type")
+ } else {
+ // Call function to update and verify updated state in store
+ callFunction()
+ val updatedStateRow = store.get(key)
assert(
- updater.getTimeoutTimestamp(updatedStateRow.get) === expectedTimeoutTimestamp,
- "final timeout timestamp not as expected")
+ updater.getStateObj(updatedStateRow).map(_.toString.toInt) === expectedState,
+ "final state not as expected")
+ if (updatedStateRow.nonEmpty) {
+ assert(
+ updater.getTimeoutTimestamp(updatedStateRow.get) === expectedTimeoutTimestamp,
+ "final timeout timestamp not as expected")
+ }
}
}
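
The reworked GroupState assertions encode the processing-time timeout rule: with the batch processing timestamp fixed at 1000 ms, setTimeoutDuration(500) produces a timeout timestamp of 1500, and that timestamp is kept across update() and remove(). A minimal sketch of just that arithmetic; TimeoutState is illustrative and not GroupStateImpl.

    // Sketch: processing-time timeout = current batch processing time + requested duration.
    class TimeoutState(batchProcessingTimeMs: Long) {
      private var timeoutTimestampMs: Long = Long.MinValue  // stand-in for NO_TIMESTAMP

      def setTimeoutDuration(durationMs: Long): Unit = {
        require(durationMs > 0, "Timeout duration must be positive")
        timeoutTimestampMs = batchProcessingTimeMs + durationMs
      }
      def getTimeoutTimestamp: Long = timeoutTimestampMs
    }

    val s = new TimeoutState(batchProcessingTimeMs = 1000)
    s.setTimeoutDuration(500)
    assert(s.getTimeoutTimestamp == 1500)  // matches the assertion in the test above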
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index 13fe51a557733..1fc062974e185 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -25,6 +25,8 @@ import scala.util.control.ControlThrowable
import org.apache.commons.io.FileUtils
+import org.apache.spark.SparkContext
+import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.execution.command.ExplainCommand
@@ -69,6 +71,27 @@ class StreamSuite extends StreamTest {
CheckAnswer(Row(1, 1, "one"), Row(2, 2, "two"), Row(4, 4, "four")))
}
+ test("SPARK-20432: union one stream with itself") {
+ val df = spark.readStream.format(classOf[FakeDefaultSource].getName).load().select("a")
+ val unioned = df.union(df)
+ withTempDir { outputDir =>
+ withTempDir { checkpointDir =>
+ val query =
+ unioned
+ .writeStream.format("parquet")
+ .option("checkpointLocation", checkpointDir.getAbsolutePath)
+ .start(outputDir.getAbsolutePath)
+ try {
+ query.processAllAvailable()
+ val outputDf = spark.read.parquet(outputDir.getAbsolutePath).as[Long]
+ checkDatasetUnorderly[Long](outputDf, (0L to 10L).union((0L to 10L)).toArray: _*)
+ } finally {
+ query.stop()
+ }
+ }
+ }
+ }
+
test("union two streams") {
val inputData1 = MemoryStream[Int]
val inputData2 = MemoryStream[Int]
@@ -120,6 +143,33 @@ class StreamSuite extends StreamTest {
assertDF(df)
}
+ test("Within the same streaming query, one StreamingRelation should only be transformed to one " +
+ "StreamingExecutionRelation") {
+ val df = spark.readStream.format(classOf[FakeDefaultSource].getName).load()
+ var query: StreamExecution = null
+ try {
+ query =
+ df.union(df)
+ .writeStream
+ .format("memory")
+ .queryName("memory")
+ .start()
+ .asInstanceOf[StreamingQueryWrapper]
+ .streamingQuery
+ query.awaitInitialization(streamingTimeout.toMillis)
+ val executionRelations =
+ query
+ .logicalPlan
+ .collect { case ser: StreamingExecutionRelation => ser }
+ assert(executionRelations.size === 2)
+ assert(executionRelations.distinct.size === 1)
+ } finally {
+ if (query != null) {
+ query.stop()
+ }
+ }
+ }
+
test("unsupported queries") {
val streamInput = MemoryStream[Int]
val batchInput = Seq(1, 2, 3).toDS()
@@ -500,6 +550,70 @@ class StreamSuite extends StreamTest {
}
}
}
+
+ test("calling stop() on a query cancels related jobs") {
+ val input = MemoryStream[Int]
+ val query = input
+ .toDS()
+ .map { i =>
+ while (!org.apache.spark.TaskContext.get().isInterrupted()) {
+ // keep looping till interrupted by query.stop()
+ Thread.sleep(100)
+ }
+ i
+ }
+ .writeStream
+ .format("console")
+ .start()
+
+ input.addData(1)
+ // wait for jobs to start
+ eventually(timeout(streamingTimeout)) {
+ assert(sparkContext.statusTracker.getActiveJobIds().nonEmpty)
+ }
+
+ query.stop()
+ // make sure jobs are stopped
+ eventually(timeout(streamingTimeout)) {
+ assert(sparkContext.statusTracker.getActiveJobIds().isEmpty)
+ }
+ }
+
+ test("batch id is updated correctly in the job description") {
+ val queryName = "memStream"
+ @volatile var jobDescription: String = null
+ def assertDescContainsQueryNameAnd(batch: Integer): Unit = {
+ // wait for listener event to be processed
+ spark.sparkContext.listenerBus.waitUntilEmpty(streamingTimeout.toMillis)
+ assert(jobDescription.contains(queryName) && jobDescription.contains(s"batch = $batch"))
+ }
+
+ spark.sparkContext.addSparkListener(new SparkListener {
+ override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+ jobDescription = jobStart.properties.getProperty(SparkContext.SPARK_JOB_DESCRIPTION)
+ }
+ })
+
+ val input = MemoryStream[Int]
+ val query = input
+ .toDS()
+ .map(_ + 1)
+ .writeStream
+ .format("memory")
+ .queryName(queryName)
+ .start()
+
+ input.addData(1)
+ query.processAllAvailable()
+ assertDescContainsQueryNameAnd(batch = 0)
+ input.addData(2, 3)
+ query.processAllAvailable()
+ assertDescContainsQueryNameAnd(batch = 1)
+ input.addData(4)
+ query.processAllAvailable()
+ assertDescContainsQueryNameAnd(batch = 2)
+ query.stop()
+ }
}
abstract class FakeSource extends StreamSourceProvider {
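
The "batch id is updated correctly" test above works by registering a SparkListener and reading the job description property that the streaming engine sets per batch. A small sketch of the same probing technique outside the suite (the object name is made up; the string literal is the key behind `SparkContext.SPARK_JOB_DESCRIPTION`):

    import java.util.concurrent.ConcurrentLinkedQueue

    import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
    import org.apache.spark.sql.SparkSession

    object JobDescriptionProbe {
      // Records the description of every job that starts on this SparkContext.
      def install(spark: SparkSession): ConcurrentLinkedQueue[String] = {
        val seen = new ConcurrentLinkedQueue[String]()
        spark.sparkContext.addSparkListener(new SparkListener {
          override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
            // "spark.job.description" is the property key behind SparkContext.SPARK_JOB_DESCRIPTION.
            seen.add(jobStart.properties.getProperty("spark.job.description"))
          }
        })
        seen
      }
    }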
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index 5bc36dd30f6d1..2a4039cc5831a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -172,8 +172,10 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts {
*
* @param isFatalError if this is a fatal error. If so, the error should also be caught by
* UncaughtExceptionHandler.
+ * @param assertFailure a function to verify the error.
*/
case class ExpectFailure[T <: Throwable : ClassTag](
+ assertFailure: Throwable => Unit = _ => {},
isFatalError: Boolean = false) extends StreamAction {
val causeClass: Class[T] = implicitly[ClassTag[T]].runtimeClass.asInstanceOf[Class[T]]
override def toString(): String =
@@ -455,6 +457,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts {
s"\tExpected: ${ef.causeClass}\n\tReturned: $streamThreadDeathCause")
streamThreadDeathCause = null
}
+ ef.assertFailure(exception.getCause)
} catch {
case _: InterruptedException =>
case e: org.scalatest.exceptions.TestFailedDueToTimeoutException =>
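
With the new `assertFailure` hook, a test can inspect the failure cause in addition to checking its class. A hypothetical usage inside a suite built on `StreamTest` (the stream, input data and message fragment are placeholders):

    // Placeholder stream and input; only the shape of the ExpectFailure call matters here.
    testStream(streamingDF)(
      AddData(inputData, 1),
      ExpectFailure[IllegalStateException](assertFailure = { e =>
        assert(e.getMessage.contains("some expected message fragment"))
      })
    )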
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
index f796a4cb4a398..b6e82b621c8cb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
@@ -69,6 +69,22 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte
)
}
+ test("count distinct") {
+ val inputData = MemoryStream[(Int, Seq[Int])]
+
+ val aggregated =
+ inputData.toDF()
+ .select($"*", explode($"_2") as 'value)
+ .groupBy($"_1")
+ .agg(size(collect_set($"value")))
+ .as[(Int, Int)]
+
+ testStream(aggregated, Update)(
+ AddData(inputData, (1, Seq(1, 2))),
+ CheckLastBatch((1, 2))
+ )
+ }
+
test("simple count, complete mode") {
val inputData = MemoryStream[Int]
@@ -251,7 +267,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte
.where('value >= current_timestamp().cast("long") - 10L)
testStream(aggregated, Complete)(
- StartStream(ProcessingTime("10 seconds"), triggerClock = clock),
+ StartStream(Trigger.ProcessingTime("10 seconds"), triggerClock = clock),
// advance clock to 10 seconds, all keys retained
AddData(inputData, 0L, 5L, 5L, 10L),
@@ -278,7 +294,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte
clock.advance(60 * 1000L)
true
},
- StartStream(ProcessingTime("10 seconds"), triggerClock = clock),
+ StartStream(Trigger.ProcessingTime("10 seconds"), triggerClock = clock),
// The commit log blown, causing the last batch to re-run
CheckLastBatch((20L, 1), (85L, 1)),
AssertOnQuery { q =>
@@ -306,7 +322,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte
.where($"value".cast("date") >= date_sub(current_date(), 10))
.select(($"value".cast("long") / DateTimeUtils.SECONDS_PER_DAY).cast("long"), $"count(1)")
testStream(aggregated, Complete)(
- StartStream(ProcessingTime("10 day"), triggerClock = clock),
+ StartStream(Trigger.ProcessingTime("10 day"), triggerClock = clock),
// advance clock to 10 days, should retain all keys
AddData(inputData, 0L, 5L, 5L, 10L),
AdvanceManualClock(DateTimeUtils.MILLIS_PER_DAY * 10),
@@ -330,7 +346,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte
clock.advance(DateTimeUtils.MILLIS_PER_DAY * 60)
true
},
- StartStream(ProcessingTime("10 day"), triggerClock = clock),
+ StartStream(Trigger.ProcessingTime("10 day"), triggerClock = clock),
// Commit log blown, causing a re-run of the last batch
CheckLastBatch((20L, 1), (85L, 1)),
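
The suite now goes through the public `Trigger.ProcessingTime` factory rather than the deprecated `ProcessingTime` case class. The equivalent call on a real query looks roughly like this (the DataFrame is a placeholder):

    import org.apache.spark.sql.streaming.Trigger

    // streamingDF is any streaming DataFrame; the query below triggers every 10 seconds.
    val query = streamingDF.writeStream
      .format("console")
      .trigger(Trigger.ProcessingTime("10 seconds"))
      .start()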
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala
index b8a694c177310..59c6a6fade175 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala
@@ -21,6 +21,7 @@ import java.util.UUID
import scala.collection.mutable
import scala.concurrent.duration._
+import scala.language.reflectiveCalls
import org.scalactic.TolerantNumerics
import org.scalatest.concurrent.AsyncAssertions.Waiter
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala
index b49efa6890236..2986b7f1eecfb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala
@@ -78,9 +78,9 @@ class StreamingQueryManagerSuite extends StreamTest with BeforeAndAfter {
eventually(Timeout(streamingTimeout)) {
require(!q2.isActive)
require(q2.exception.isDefined)
+ assert(spark.streams.get(q2.id) === null)
+ assert(spark.streams.active.toSet === Set(q3))
}
- assert(spark.streams.get(q2.id) === null)
- assert(spark.streams.active.toSet === Set(q3))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
index b69536ed37463..9e65aa8d09e1d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
@@ -425,6 +425,29 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
}
}
+ test("SPARK-22975: MetricsReporter defaults when there was no progress reported") {
+ withSQLConf("spark.sql.streaming.metricsEnabled" -> "true") {
+ BlockingSource.latch = new CountDownLatch(1)
+ withTempDir { tempDir =>
+ val sq = spark.readStream
+ .format("org.apache.spark.sql.streaming.util.BlockingSource")
+ .load()
+ .writeStream
+ .format("org.apache.spark.sql.streaming.util.BlockingSource")
+ .option("checkpointLocation", tempDir.toString)
+ .start()
+ .asInstanceOf[StreamingQueryWrapper]
+ .streamingQuery
+
+ val gauges = sq.streamMetrics.metricRegistry.getGauges
+ assert(gauges.get("latency").getValue.asInstanceOf[Long] == 0)
+ assert(gauges.get("processingRate-total").getValue.asInstanceOf[Double] == 0.0)
+ assert(gauges.get("inputRate-total").getValue.asInstanceOf[Double] == 0.0)
+ sq.stop()
+ }
+ }
+ }
+
test("input row calculation with mixed batch and streaming sources") {
val streamingTriggerDF = spark.createDataset(1 to 10).toDF
val streamingInputDF = createSingleTriggerStreamingDF(streamingTriggerDF).toDF("value")
@@ -510,22 +533,22 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
.start()
}
- val input = MemoryStream[Int]
- val q1 = startQuery(input.toDS, "stream_serializable_test_1")
- val q2 = startQuery(input.toDS.map { i =>
+ val input = MemoryStream[Int] :: MemoryStream[Int] :: MemoryStream[Int] :: Nil
+ val q1 = startQuery(input(0).toDS, "stream_serializable_test_1")
+ val q2 = startQuery(input(1).toDS.map { i =>
// Emulate that `StreamingQuery` get captured with normal usage unintentionally.
// It should not fail the query.
q1
i
}, "stream_serializable_test_2")
- val q3 = startQuery(input.toDS.map { i =>
+ val q3 = startQuery(input(2).toDS.map { i =>
// Emulate that `StreamingQuery` is used in executors. We should fail the query with a clear
// error message.
q1.explain()
i
}, "stream_serializable_test_3")
try {
- input.addData(1)
+ input.foreach(_.addData(1))
// q2 should not fail since it doesn't use `q1` in the closure
q2.processAllAvailable()
@@ -613,6 +636,18 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
}
}
+ test("processAllAvailable should not block forever when a query is stopped") {
+ val input = MemoryStream[Int]
+ input.addData(1)
+ val query = input.toDF().writeStream
+ .trigger(Trigger.Once())
+ .format("console")
+ .start()
+ failAfter(streamingTimeout) {
+ query.processAllAvailable()
+ }
+ }
+
/** Create a streaming DF that only execute one batch in which it returns the given static DF */
private def createSingleTriggerStreamingDF(triggerDF: DataFrame): DataFrame = {
require(!triggerDF.isStreaming)
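
The new `processAllAvailable` test relies on `Trigger.Once()` running exactly one batch, after which the query stops on its own, so `processAllAvailable()` must return instead of blocking forever. A minimal sketch of that interaction (the DataFrame is a placeholder):

    import org.apache.spark.sql.streaming.Trigger

    // With Trigger.Once the query processes a single batch and terminates by itself;
    // processAllAvailable() should return once that batch has been committed.
    val once = streamingDF.writeStream
      .format("console")
      .trigger(Trigger.Once())
      .start()
    once.processAllAvailable()
    once.awaitTermination()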
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
index dc2506a48ad00..bae9d811f7790 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
@@ -641,6 +641,7 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter {
test("temp checkpoint dir should be deleted if a query is stopped without errors") {
import testImplicits._
val query = MemoryStream[Int].toDS.writeStream.format("console").start()
+ query.processAllAvailable()
val checkpointDir = new Path(
query.asInstanceOf[StreamingQueryWrapper].streamingQuery.checkpointRoot)
val fs = checkpointDir.getFileSystem(spark.sessionState.newHadoopConf())
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index 6a4cc95d36bea..f6d47734d7e83 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -20,13 +20,15 @@ package org.apache.spark.sql.test
import java.io.File
import java.net.URI
import java.nio.file.Files
-import java.util.UUID
+import java.util.{Locale, UUID}
+import scala.concurrent.duration._
import scala.language.implicitConversions
import scala.util.control.NonFatal
import org.apache.hadoop.fs.Path
import org.scalatest.BeforeAndAfterAll
+import org.scalatest.concurrent.Eventually
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql._
@@ -49,7 +51,7 @@ import org.apache.spark.util.{UninterruptibleThread, Utils}
* prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM.
*/
private[sql] trait SQLTestUtils
- extends SparkFunSuite
+ extends SparkFunSuite with Eventually
with BeforeAndAfterAll
with SQLTestData { self =>
@@ -138,6 +140,15 @@ private[sql] trait SQLTestUtils
}
}
+ /**
+ * Waits for all tasks on all executors to be finished.
+ */
+ protected def waitForTasksToFinish(): Unit = {
+ eventually(timeout(10.seconds)) {
+ assert(spark.sparkContext.statusTracker
+ .getExecutorInfos.map(_.numRunningTasks()).sum == 0)
+ }
+ }
/**
* Creates a temporary directory, which is then passed to `f` and will be deleted after `f`
* returns.
@@ -146,7 +157,11 @@ private[sql] trait SQLTestUtils
*/
protected def withTempDir(f: File => Unit): Unit = {
val dir = Utils.createTempDir().getCanonicalFile
- try f(dir) finally Utils.deleteRecursively(dir)
+ try f(dir) finally {
+ // wait for all tasks to finish before deleting files
+ waitForTasksToFinish()
+ Utils.deleteRecursively(dir)
+ }
}
/**
@@ -222,12 +237,39 @@ private[sql] trait SQLTestUtils
try f(dbName) finally {
if (spark.catalog.currentDatabase == dbName) {
- spark.sql(s"USE ${DEFAULT_DATABASE}")
+ spark.sql(s"USE $DEFAULT_DATABASE")
}
spark.sql(s"DROP DATABASE $dbName CASCADE")
}
}
+ /**
+ * Drops all the databases in `dbNames` after calling `f`, then switches back to the default
+ * database.
+ */
+ protected def withDatabase(dbNames: String*)(f: => Unit): Unit = {
+ try f finally {
+ dbNames.foreach { name =>
+ spark.sql(s"DROP DATABASE IF EXISTS $name")
+ }
+ spark.sql(s"USE $DEFAULT_DATABASE")
+ }
+ }
+
+ /**
+ * Sets the JVM default Locale to `language` before executing `f`, then switches back to the
+ * original default locale after `f` returns.
+ */
+ protected def withLocale(language: String)(f: => Unit): Unit = {
+ val originalLocale = Locale.getDefault
+ try {
+ // Add Locale setting
+ Locale.setDefault(new Locale(language))
+ f
+ } finally {
+ Locale.setDefault(originalLocale)
+ }
+ }
+
/**
* Activates database `db` before executing `f`, then switches back to `default` database after
* `f` returns.
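
A hypothetical use of the new `withLocale` helper: locale-sensitive APIs such as `String.toLowerCase()` depend on the JVM default locale, so a test can pin it for the duration of the block:

    // Assumes a suite mixing in SQLTestUtils; "tr" is the Turkish locale, where the
    // lower case of "I" is the dotless ı (U+0131) rather than "i".
    withLocale("tr") {
      assert("I".toLowerCase == "\u0131")
    }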
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala
index e122b39f6fc40..7cea4c02155ea 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala
@@ -17,19 +17,22 @@
package org.apache.spark.sql.test
+import scala.concurrent.duration._
+
import org.scalatest.BeforeAndAfterEach
+import org.scalatest.concurrent.Eventually
import org.apache.spark.{DebugFilesystem, SparkConf}
import org.apache.spark.sql.{SparkSession, SQLContext}
-import org.apache.spark.sql.internal.SQLConf
-
/**
* Helper trait for SQL test suites where all tests share a single [[TestSparkSession]].
*/
-trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach {
+trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach with Eventually {
- protected val sparkConf = new SparkConf()
+ protected def sparkConf = {
+ new SparkConf().set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)
+ }
/**
* The [[TestSparkSession]] to use for all tests in this suite.
@@ -50,8 +53,7 @@ trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach {
protected implicit def sqlContext: SQLContext = _spark.sqlContext
protected def createSparkSession: TestSparkSession = {
- new TestSparkSession(
- sparkConf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName))
+ new TestSparkSession(sparkConf)
}
/**
@@ -72,6 +74,7 @@ trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach {
protected override def afterAll(): Unit = {
super.afterAll()
if (_spark != null) {
+ _spark.sessionState.catalog.reset()
_spark.stop()
_spark = null
}
@@ -84,6 +87,10 @@ trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach {
protected override def afterEach(): Unit = {
super.afterEach()
- DebugFilesystem.assertNoOpenStreams()
+ // files can be closed from other threads, so wait a bit
+ // normally this doesn't take more than 1s
+ eventually(timeout(10.seconds)) {
+ DebugFilesystem.assertNoOpenStreams()
+ }
}
}
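
The `afterEach` check is now retried because files opened through DebugFilesystem may be closed by other threads shortly after the test body returns. The same polling idiom from ScalaTest, shown in isolation (the check itself is a placeholder):

    import org.scalatest.concurrent.Eventually._
    import org.scalatest.time.SpanSugar._

    // Retry an assertion that only becomes true slightly after the triggering action.
    eventually(timeout(10.seconds), interval(100.millis)) {
      assert(openStreamCount() == 0) // openStreamCount() stands in for the real check
    }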
diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml
index 9c879218ddc0d..2fe49aca182a4 100644
--- a/sql/hive-thriftserver/pom.xml
+++ b/sql/hive-thriftserver/pom.xml
@@ -22,7 +22,7 @@
  <parent>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.11</artifactId>
-   <version>2.2.0-SNAPSHOT</version>
+   <version>2.2.3-SNAPSHOT</version>
    <relativePath>../../pom.xml</relativePath>
  </parent>
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
index ff3784cab9e26..1d1074a2a7387 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
@@ -253,6 +253,8 @@ private[hive] class SparkExecuteStatementOperation(
return
} else {
setState(OperationState.ERROR)
+ HiveThriftServer2.listener.onStatementError(
+ statementId, e.getMessage, SparkUtils.exceptionString(e))
throw e
}
// Actually do need to catch Throwable as some failures don't inherit from Exception and
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala
index f39e9dcd3a5bb..38b8605745752 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala
@@ -39,7 +39,8 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab)
/** Render the page */
def render(request: HttpServletRequest): Seq[Node] = {
- val parameterId = request.getParameter("id")
+ // stripXSS is called first to remove suspicious characters used in XSS attacks
+ val parameterId = UIUtils.stripXSS(request.getParameter("id"))
require(parameterId != null && parameterId.nonEmpty, "Missing id parameter")
val content =
@@ -197,4 +198,3 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab)
UIUtils.listingTable(headers, generateDataRow, data, fixedWidth = true)
}
}
-
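
The session page now sanitizes the `id` request parameter before using it. As a rough illustration of the idea only (this is not Spark's actual `UIUtils.stripXSS` implementation), a sanitizer of this kind strips sequences commonly used to break out of the surrounding HTML context and escapes the rest:

    // Illustrative only: drop line breaks and quotes, then HTML-escape the remainder.
    // The real UIUtils.stripXSS may differ in the exact characters it handles.
    def stripXSSLike(parameter: String): String =
      if (parameter == null) null
      else org.apache.commons.lang3.StringEscapeUtils.escapeHtml4(
        parameter.replaceAll("(\r\n|\r|\n|'|%0D|%0A|%27)", ""))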
diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml
index 0f249d7d59351..64e0a86699ea6 100644
--- a/sql/hive/pom.xml
+++ b/sql/hive/pom.xml
@@ -22,7 +22,7 @@
  <parent>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.11</artifactId>
-   <version>2.2.0-SNAPSHOT</version>
+   <version>2.2.3-SNAPSHOT</version>
    <relativePath>../../pom.xml</relativePath>
  </parent>
@@ -162,6 +162,10 @@
      <groupId>org.apache.thrift</groupId>
      <artifactId>libfb303</artifactId>
    </dependency>
+   <dependency>
+     <groupId>org.apache.derby</groupId>
+     <artifactId>derby</artifactId>
+   </dependency>
    <dependency>
org.scalacheck
scalacheck_${scala.binary.version}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
index 8b0fdf49cefab..1c26d9822e120 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
@@ -114,7 +114,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
* should interpret these special data source properties and restore the original table metadata
* before returning it.
*/
- private def getRawTable(db: String, table: String): CatalogTable = withClient {
+ private[hive] def getRawTable(db: String, table: String): CatalogTable = withClient {
client.getTable(db, table)
}
@@ -137,17 +137,34 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
}
}
+ /**
+ * Checks the validity of data column names. The Hive metastore does not allow commas in data
+ * column names; partition columns and views are not subject to this restriction.
+ */
+ private def verifyDataSchema(
+ tableName: TableIdentifier, tableType: CatalogTableType, dataSchema: StructType): Unit = {
+ if (tableType != VIEW) {
+ dataSchema.map(_.name).foreach { colName =>
+ if (colName.contains(",")) {
+ throw new AnalysisException("Cannot create a table having a column whose name contains " +
+ s"commas in Hive metastore. Table: $tableName; Column: $colName")
+ }
+ }
+ }
+ }
+
// --------------------------------------------------------------------------
// Databases
// --------------------------------------------------------------------------
- override def createDatabase(
+ override protected def doCreateDatabase(
dbDefinition: CatalogDatabase,
ignoreIfExists: Boolean): Unit = withClient {
client.createDatabase(dbDefinition, ignoreIfExists)
}
- override def dropDatabase(
+ override protected def doDropDatabase(
db: String,
ignoreIfNotExists: Boolean,
cascade: Boolean): Unit = withClient {
@@ -194,7 +211,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
// Tables
// --------------------------------------------------------------------------
- override def createTable(
+ override protected def doCreateTable(
tableDefinition: CatalogTable,
ignoreIfExists: Boolean): Unit = withClient {
assert(tableDefinition.identifier.database.isDefined)
@@ -202,44 +219,43 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
val table = tableDefinition.identifier.table
requireDbExists(db)
verifyTableProperties(tableDefinition)
+ verifyDataSchema(
+ tableDefinition.identifier, tableDefinition.tableType, tableDefinition.dataSchema)
if (tableExists(db, table) && !ignoreIfExists) {
throw new TableAlreadyExistsException(db = db, table = table)
}
- if (tableDefinition.tableType == VIEW) {
- client.createTable(tableDefinition, ignoreIfExists)
+ // Ideally we should not create a managed table with location, but Hive serde table can
+ // specify location for managed table. And in [[CreateDataSourceTableAsSelectCommand]] we have
+ // to create the table directory and write out data before we create this table, to avoid
+ // exposing a partial written table.
+ val needDefaultTableLocation = tableDefinition.tableType == MANAGED &&
+ tableDefinition.storage.locationUri.isEmpty
+
+ val tableLocation = if (needDefaultTableLocation) {
+ Some(CatalogUtils.stringToURI(defaultTablePath(tableDefinition.identifier)))
} else {
- // Ideally we should not create a managed table with location, but Hive serde table can
- // specify location for managed table. And in [[CreateDataSourceTableAsSelectCommand]] we have
- // to create the table directory and write out data before we create this table, to avoid
- // exposing a partial written table.
- val needDefaultTableLocation = tableDefinition.tableType == MANAGED &&
- tableDefinition.storage.locationUri.isEmpty
-
- val tableLocation = if (needDefaultTableLocation) {
- Some(CatalogUtils.stringToURI(defaultTablePath(tableDefinition.identifier)))
- } else {
- tableDefinition.storage.locationUri
- }
+ tableDefinition.storage.locationUri
+ }
- if (DDLUtils.isHiveTable(tableDefinition)) {
- val tableWithDataSourceProps = tableDefinition.copy(
- // We can't leave `locationUri` empty and count on Hive metastore to set a default table
- // location, because Hive metastore uses hive.metastore.warehouse.dir to generate default
- // table location for tables in default database, while we expect to use the location of
- // default database.
- storage = tableDefinition.storage.copy(locationUri = tableLocation),
- // Here we follow data source tables and put table metadata like table schema, partition
- // columns etc. in table properties, so that we can work around the Hive metastore issue
- // about not case preserving and make Hive serde table support mixed-case column names.
- properties = tableDefinition.properties ++ tableMetaToTableProps(tableDefinition))
- client.createTable(tableWithDataSourceProps, ignoreIfExists)
- } else {
- createDataSourceTable(
- tableDefinition.withNewStorage(locationUri = tableLocation),
- ignoreIfExists)
- }
+ if (DDLUtils.isDatasourceTable(tableDefinition)) {
+ createDataSourceTable(
+ tableDefinition.withNewStorage(locationUri = tableLocation),
+ ignoreIfExists)
+ } else {
+ val tableWithDataSourceProps = tableDefinition.copy(
+ // We can't leave `locationUri` empty and count on Hive metastore to set a default table
+ // location, because Hive metastore uses hive.metastore.warehouse.dir to generate default
+ // table location for tables in default database, while we expect to use the location of
+ // default database.
+ storage = tableDefinition.storage.copy(locationUri = tableLocation),
+ // Here we follow data source tables and put table metadata like table schema, partition
+ // columns etc. in table properties, so that we can work around the Hive metastore issue
+ // about not case preserving and make Hive serde table and view support mixed-case column
+ // names.
+ properties = tableDefinition.properties ++ tableMetaToTableProps(tableDefinition))
+ client.createTable(tableWithDataSourceProps, ignoreIfExists)
}
}
@@ -281,7 +297,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
storage = table.storage.copy(
locationUri = None,
properties = storagePropsWithLocation),
- schema = table.partitionSchema,
+ schema = StructType(EMPTY_DATA_SCHEMA ++ table.partitionSchema),
bucketSpec = None,
properties = table.properties ++ tableProperties)
}
@@ -298,6 +314,15 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
None
}
+ // TODO: an empty data schema is not Hive compatible; we only do this to keep the behavior as it
+ // was, because previously we generated the special empty schema in `HiveClient`. Remove this in
+ // Spark 2.3.
+ val schema = if (table.dataSchema.isEmpty) {
+ StructType(EMPTY_DATA_SCHEMA ++ table.partitionSchema)
+ } else {
+ table.schema
+ }
+
table.copy(
storage = table.storage.copy(
locationUri = location,
@@ -306,6 +331,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
serde = serde.serde,
properties = storagePropsWithLocation
),
+ schema = schema,
properties = table.properties ++ tableProperties)
}
@@ -372,6 +398,12 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
* can be used as table properties later.
*/
private def tableMetaToTableProps(table: CatalogTable): mutable.Map[String, String] = {
+ tableMetaToTableProps(table, table.schema)
+ }
+
+ private def tableMetaToTableProps(
+ table: CatalogTable,
+ schema: StructType): mutable.Map[String, String] = {
val partitionColumns = table.partitionColumnNames
val bucketSpec = table.bucketSpec
@@ -380,7 +412,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
// property. In this case, we split the JSON string and store each part as a separate table
// property.
val threshold = conf.get(SCHEMA_STRING_LENGTH_THRESHOLD)
- val schemaJsonString = table.schema.json
+ val schemaJsonString = schema.json
// Split the JSON string.
val parts = schemaJsonString.grouped(threshold).toSeq
properties.put(DATASOURCE_SCHEMA_NUMPARTS, parts.size.toString)
@@ -456,7 +488,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
}
}
- override def dropTable(
+ override protected def doDropTable(
db: String,
table: String,
ignoreIfNotExists: Boolean,
@@ -465,7 +497,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
client.dropTable(db, table, ignoreIfNotExists, purge)
}
- override def renameTable(db: String, oldName: String, newName: String): Unit = withClient {
+ override protected def doRenameTable(
+ db: String,
+ oldName: String,
+ newName: String): Unit = withClient {
val rawTable = getRawTable(db, oldName)
// Note that Hive serde tables don't use path option in storage properties to store the value
@@ -607,22 +642,32 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
}
}
- override def alterTableSchema(db: String, table: String, schema: StructType): Unit = withClient {
+ override def alterTableDataSchema(
+ db: String, table: String, newDataSchema: StructType): Unit = withClient {
requireTableExists(db, table)
- val rawTable = getRawTable(db, table)
- val withNewSchema = rawTable.copy(schema = schema)
- // Add table metadata such as table schema, partition columns, etc. to table properties.
- val updatedTable = withNewSchema.copy(
- properties = withNewSchema.properties ++ tableMetaToTableProps(withNewSchema))
- try {
- client.alterTable(updatedTable)
- } catch {
- case NonFatal(e) =>
- val warningMessage =
- s"Could not alter schema of table ${rawTable.identifier.quotedString} in a Hive " +
- "compatible way. Updating Hive metastore in Spark SQL specific format."
- logWarning(warningMessage, e)
- client.alterTable(updatedTable.copy(schema = updatedTable.partitionSchema))
+ val oldTable = getTable(db, table)
+ verifyDataSchema(oldTable.identifier, oldTable.tableType, newDataSchema)
+ val schemaProps =
+ tableMetaToTableProps(oldTable, StructType(newDataSchema ++ oldTable.partitionSchema)).toMap
+
+ if (isDatasourceTable(oldTable)) {
+ // For data source tables, first try to write it with the schema set; if that does not work,
+ // try again with updated properties and the partition schema. This is a simplified version of
+ // what createDataSourceTable() does, and may leave the table in a state unreadable by Hive
+ // (for example, the schema does not match the data source schema, or does not match the
+ // storage descriptor).
+ try {
+ client.alterTableDataSchema(db, table, newDataSchema, schemaProps)
+ } catch {
+ case NonFatal(e) =>
+ val warningMessage =
+ s"Could not alter schema of table ${oldTable.identifier.quotedString} in a Hive " +
+ "compatible way. Updating Hive metastore in Spark SQL specific format."
+ logWarning(warningMessage, e)
+ client.alterTableDataSchema(db, table, EMPTY_DATA_SCHEMA, schemaProps)
+ }
+ } else {
+ client.alterTableDataSchema(db, table, newDataSchema, schemaProps)
}
}
@@ -630,10 +675,6 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
restoreTableMetadata(getRawTable(db, table))
}
- override def getTableOption(db: String, table: String): Option[CatalogTable] = withClient {
- client.getTableOption(db, table).map(restoreTableMetadata)
- }
-
/**
* Restores table metadata from the table properties. This method is kind of a opposite version
* of [[createTable]].
@@ -648,16 +689,21 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
var table = inputTable
- if (table.tableType != VIEW) {
- table.properties.get(DATASOURCE_PROVIDER) match {
- // No provider in table properties, which means this is a Hive serde table.
- case None =>
- table = restoreHiveSerdeTable(table)
+ table.properties.get(DATASOURCE_PROVIDER) match {
+ case None if table.tableType == VIEW =>
+ // If this is a view created by Spark 2.2 or higher versions, we should restore its schema
+ // from table properties.
+ if (table.properties.contains(DATASOURCE_SCHEMA_NUMPARTS)) {
+ table = table.copy(schema = getSchemaFromTableProperties(table))
+ }
- // This is a regular data source table.
- case Some(provider) =>
- table = restoreDataSourceTable(table, provider)
- }
+ // No provider in table properties, which means this is a Hive serde table.
+ case None =>
+ table = restoreHiveSerdeTable(table)
+
+ // This is a regular data source table.
+ case Some(provider) =>
+ table = restoreDataSourceTable(table, provider)
}
// construct Spark's statistics from information in Hive metastore
@@ -696,6 +742,20 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
properties = table.properties.filterNot { case (key, _) => key.startsWith(SPARK_SQL_PREFIX) })
}
+ // Reorder the table schema to put partition columns at the end. Before Spark 2.2, the partition
+ // columns were not placed at the end of the schema. We need to reorder the schema when reading
+ // it back from the table properties.
+ private def reorderSchema(schema: StructType, partColumnNames: Seq[String]): StructType = {
+ val partitionFields = partColumnNames.map { partCol =>
+ schema.find(_.name == partCol).getOrElse {
+ throw new AnalysisException("The metadata is corrupted. Unable to find the " +
+ s"partition column names from the schema. schema: ${schema.catalogString}. " +
+ s"Partition columns: ${partColumnNames.mkString("[", ", ", "]")}")
+ }
+ }
+ StructType(schema.filterNot(partitionFields.contains) ++ partitionFields)
+ }
+
private def restoreHiveSerdeTable(table: CatalogTable): CatalogTable = {
val hiveTable = table.copy(
provider = Some(DDLUtils.HIVE_PROVIDER),
@@ -705,10 +765,13 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
// schema from table properties.
if (table.properties.contains(DATASOURCE_SCHEMA_NUMPARTS)) {
val schemaFromTableProps = getSchemaFromTableProperties(table)
- if (DataType.equalsIgnoreCaseAndNullability(schemaFromTableProps, table.schema)) {
+ val partColumnNames = getPartitionColumnsFromTableProperties(table)
+ val reorderedSchema = reorderSchema(schema = schemaFromTableProps, partColumnNames)
+
+ if (DataType.equalsIgnoreCaseAndNullability(reorderedSchema, table.schema)) {
hiveTable.copy(
- schema = schemaFromTableProps,
- partitionColumnNames = getPartitionColumnsFromTableProperties(table),
+ schema = reorderedSchema,
+ partitionColumnNames = partColumnNames,
bucketSpec = getBucketSpecFromTableProperties(table))
} else {
// Hive metastore may change the table schema, e.g. schema inference. If the table
@@ -738,11 +801,15 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
}
val partitionProvider = table.properties.get(TABLE_PARTITION_PROVIDER)
+ val schemaFromTableProps = getSchemaFromTableProperties(table)
+ val partColumnNames = getPartitionColumnsFromTableProperties(table)
+ val reorderedSchema = reorderSchema(schema = schemaFromTableProps, partColumnNames)
+
table.copy(
provider = Some(provider),
storage = storageWithLocation,
- schema = getSchemaFromTableProperties(table),
- partitionColumnNames = getPartitionColumnsFromTableProperties(table),
+ schema = reorderedSchema,
+ partitionColumnNames = partColumnNames,
bucketSpec = getBucketSpecFromTableProperties(table),
tracksPartitionsInCatalog = partitionProvider == Some(TABLE_PARTITION_PROVIDER_CATALOG))
}
@@ -1030,9 +1097,19 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
table: String,
partialSpec: Option[TablePartitionSpec] = None): Seq[CatalogTablePartition] = withClient {
val partColNameMap = buildLowerCasePartColNameMap(getTable(db, table))
- client.getPartitions(db, table, partialSpec.map(lowerCasePartitionSpec)).map { part =>
+ val res = client.getPartitions(db, table, partialSpec.map(lowerCasePartitionSpec)).map { part =>
part.copy(spec = restorePartitionSpec(part.spec, partColNameMap))
}
+
+ partialSpec match {
+ // This might be a bug in Hive: when a partition value inside the partial partition spec
+ // contains a dot and we ask Hive to list partitions w.r.t. that partial spec, Hive treats
+ // the dot as matching any single character and may return more partitions than expected.
+ // Here we do an extra filter to drop the unexpected partitions.
+ case Some(spec) if spec.exists(_._2.contains(".")) =>
+ res.filter(p => isPartialPartitionSpec(spec, p.spec))
+ case _ => res
+ }
}
override def listPartitionsByFilter(
@@ -1056,7 +1133,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
// Functions
// --------------------------------------------------------------------------
- override def createFunction(
+ override protected def doCreateFunction(
db: String,
funcDefinition: CatalogFunction): Unit = withClient {
requireDbExists(db)
@@ -1069,12 +1146,15 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
client.createFunction(db, funcDefinition.copy(identifier = functionIdentifier))
}
- override def dropFunction(db: String, name: String): Unit = withClient {
+ override protected def doDropFunction(db: String, name: String): Unit = withClient {
requireFunctionExists(db, name)
client.dropFunction(db, name)
}
- override def renameFunction(db: String, oldName: String, newName: String): Unit = withClient {
+ override protected def doRenameFunction(
+ db: String,
+ oldName: String,
+ newName: String): Unit = withClient {
requireFunctionExists(db, oldName)
requireFunctionNotExists(db, newName)
client.renameFunction(db, oldName, newName)
@@ -1123,6 +1203,15 @@ object HiveExternalCatalog {
val TABLE_PARTITION_PROVIDER_CATALOG = "catalog"
val TABLE_PARTITION_PROVIDER_FILESYSTEM = "filesystem"
+ // When storing data source tables in the Hive metastore, we need to set the data schema to empty
+ // if the schema is Hive-incompatible. However, we need a hack to preserve existing behavior:
+ // before Spark 2.0 we did not set a default serde here (this was done in Hive), so if the user
+ // provided an empty schema, Hive would automatically populate it with a single field "col".
+ // After SPARK-14388 we set the default serde to LazySimpleSerde, so this implicit behavior no
+ // longer happens and we need to do it in Spark ourselves.
+ val EMPTY_DATA_SCHEMA = new StructType()
+ .add("col", "array", nullable = true, comment = "from deserializer")
+
/**
* Returns the fully qualified name used in table properties for a particular column stat.
* For example, for column "mycol", and "min" stat, this should return
@@ -1193,4 +1282,14 @@ object HiveExternalCatalog {
getColumnNamesByType(metadata.properties, "sort", "sorting columns"))
}
}
+
+ /**
+ * Detects a data source table. This checks both the table provider and the table properties,
+ * unlike DDLUtils which just checks the former.
+ */
+ private[spark] def isDatasourceTable(table: CatalogTable): Boolean = {
+ val provider = table.provider.orElse(table.properties.get(DATASOURCE_PROVIDER))
+ provider.isDefined && provider != Some(DDLUtils.HIVE_PROVIDER)
+ }
+
}
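
The extra filter in `listPartitions` exists because Hive may treat a dot inside a partition value of a partial spec as a single-character wildcard. Conceptually, the filter keeps a partition only if its spec literally contains every entry of the partial spec; a simplified restatement of that check (not the helper the patch actually calls):

    import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec

    // Keep a partition only when its full spec matches every (column -> value) pair of the
    // partial spec exactly, with "." treated as a literal character.
    def matchesPartialSpec(partial: TablePartitionSpec, full: TablePartitionSpec): Boolean =
      partial.forall { case (col, value) => full.get(col).contains(value) }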
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 6b98066cb76c8..f858dd9319c8a 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -41,7 +41,7 @@ import org.apache.spark.sql.types._
private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Logging {
// these are def_s and not val/lazy val since the latter would introduce circular references
private def sessionState = sparkSession.sessionState
- private def tableRelationCache = sparkSession.sessionState.catalog.tableRelationCache
+ private def catalogProxy = sparkSession.sessionState.catalog
import HiveMetastoreCatalog._
/** These locks guard against multiple attempts to instantiate a table, which wastes memory. */
@@ -61,7 +61,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
val key = QualifiedTableName(
table.database.getOrElse(sessionState.catalog.getCurrentDatabase).toLowerCase,
table.table.toLowerCase)
- tableRelationCache.getIfPresent(key)
+ catalogProxy.getCachedTable(key)
}
private def getCached(
@@ -71,7 +71,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
expectedFileFormat: Class[_ <: FileFormat],
partitionSchema: Option[StructType]): Option[LogicalRelation] = {
- tableRelationCache.getIfPresent(tableIdentifier) match {
+ catalogProxy.getCachedTable(tableIdentifier) match {
case null => None // Cache miss
case logical @ LogicalRelation(relation: HadoopFsRelation, _, _) =>
val cachedRelationFileFormatClass = relation.fileFormat.getClass
@@ -92,27 +92,27 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
Some(logical)
} else {
// If the cached relation is not updated, we invalidate it right away.
- tableRelationCache.invalidate(tableIdentifier)
+ catalogProxy.invalidateCachedTable(tableIdentifier)
None
}
case _ =>
logWarning(s"Table $tableIdentifier should be stored as $expectedFileFormat. " +
s"However, we are getting a ${relation.fileFormat} from the metastore cache. " +
"This cached entry will be invalidated.")
- tableRelationCache.invalidate(tableIdentifier)
+ catalogProxy.invalidateCachedTable(tableIdentifier)
None
}
case other =>
logWarning(s"Table $tableIdentifier should be stored as $expectedFileFormat. " +
s"However, we are getting a $other from the metastore cache. " +
"This cached entry will be invalidated.")
- tableRelationCache.invalidate(tableIdentifier)
+ catalogProxy.invalidateCachedTable(tableIdentifier)
None
}
}
def convertToLogicalRelation(
- relation: CatalogRelation,
+ relation: HiveTableRelation,
options: Map[String, String],
fileFormatClass: Class[_ <: FileFormat],
fileType: String): LogicalRelation = {
@@ -164,19 +164,18 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
}
}
- val (dataSchema, updatedTable) =
- inferIfNeeded(relation, options, fileFormat, Option(fileIndex))
+ val updatedTable = inferIfNeeded(relation, options, fileFormat, Option(fileIndex))
val fsRelation = HadoopFsRelation(
location = fileIndex,
partitionSchema = partitionSchema,
- dataSchema = dataSchema,
+ dataSchema = updatedTable.dataSchema,
// We don't support hive bucketed tables, only ones we write out.
bucketSpec = None,
fileFormat = fileFormat,
options = options)(sparkSession = sparkSession)
val created = LogicalRelation(fsRelation, updatedTable)
- tableRelationCache.put(tableIdentifier, created)
+ catalogProxy.cacheTable(tableIdentifier, created)
created
}
@@ -192,27 +191,27 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
fileFormatClass,
None)
val logicalRelation = cached.getOrElse {
- val (dataSchema, updatedTable) = inferIfNeeded(relation, options, fileFormat)
+ val updatedTable = inferIfNeeded(relation, options, fileFormat)
val created =
LogicalRelation(
DataSource(
sparkSession = sparkSession,
paths = rootPath.toString :: Nil,
- userSpecifiedSchema = Option(dataSchema),
+ userSpecifiedSchema = Option(updatedTable.dataSchema),
// We don't support hive bucketed tables, only ones we write out.
bucketSpec = None,
options = options,
className = fileType).resolveRelation(),
table = updatedTable)
- tableRelationCache.put(tableIdentifier, created)
+ catalogProxy.cacheTable(tableIdentifier, created)
created
}
logicalRelation
})
}
- // The inferred schema may have different filed names as the table schema, we should respect
+ // The inferred schema may have different field names as the table schema, we should respect
// it, but also respect the exprId in table relation output.
assert(result.output.length == relation.output.length &&
result.output.zip(relation.output).forall { case (a1, a2) => a1.dataType == a2.dataType })
@@ -223,10 +222,10 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
}
private def inferIfNeeded(
- relation: CatalogRelation,
+ relation: HiveTableRelation,
options: Map[String, String],
fileFormat: FileFormat,
- fileIndexOpt: Option[FileIndex] = None): (StructType, CatalogTable) = {
+ fileIndexOpt: Option[FileIndex] = None): CatalogTable = {
val inferenceMode = sparkSession.sessionState.conf.caseSensitiveInferenceMode
val shouldInfer = (inferenceMode != NEVER_INFER) && !relation.tableMeta.schemaPreservesCase
val tableName = relation.tableMeta.identifier.unquotedString
@@ -243,28 +242,28 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
sparkSession,
options,
fileIndex.listFiles(Nil, Nil).flatMap(_.files))
- .map(mergeWithMetastoreSchema(relation.tableMeta.schema, _))
+ .map(mergeWithMetastoreSchema(relation.tableMeta.dataSchema, _))
inferredSchema match {
- case Some(schema) =>
+ case Some(dataSchema) =>
if (inferenceMode == INFER_AND_SAVE) {
- updateCatalogSchema(relation.tableMeta.identifier, schema)
+ updateDataSchema(relation.tableMeta.identifier, dataSchema)
}
- (schema, relation.tableMeta.copy(schema = schema))
+ val newSchema = StructType(dataSchema ++ relation.tableMeta.partitionSchema)
+ relation.tableMeta.copy(schema = newSchema)
case None =>
logWarning(s"Unable to infer schema for table $tableName from file format " +
s"$fileFormat (inference mode: $inferenceMode). Using metastore schema.")
- (relation.tableMeta.schema, relation.tableMeta)
+ relation.tableMeta
}
} else {
- (relation.tableMeta.schema, relation.tableMeta)
+ relation.tableMeta
}
}
- private def updateCatalogSchema(identifier: TableIdentifier, schema: StructType): Unit = try {
- val db = identifier.database.get
+ private def updateDataSchema(identifier: TableIdentifier, newDataSchema: StructType): Unit = try {
logInfo(s"Saving case-sensitive schema for table ${identifier.unquotedString}")
- sparkSession.sharedState.externalCatalog.alterTableSchema(db, identifier.table, schema)
+ sparkSession.sessionState.catalog.alterTableDataSchema(identifier, newDataSchema)
} catch {
case NonFatal(ex) =>
logWarning(s"Unable to save case-sensitive schema for table ${identifier.unquotedString}", ex)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index 377d4f2473c58..6227e780c0409 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -140,7 +140,7 @@ private[sql] class HiveSessionCatalog(
// Hive is case insensitive.
val functionName = funcName.unquotedString.toLowerCase(Locale.ROOT)
if (!hiveFunctions.contains(functionName)) {
- failFunctionLookup(funcName.unquotedString)
+ failFunctionLookup(funcName)
}
// TODO: Remove this fallback path once we implement the list of fallback functions
@@ -148,12 +148,12 @@ private[sql] class HiveSessionCatalog(
val functionInfo = {
try {
Option(HiveFunctionRegistry.getFunctionInfo(functionName)).getOrElse(
- failFunctionLookup(funcName.unquotedString))
+ failFunctionLookup(funcName))
} catch {
// If HiveFunctionRegistry.getFunctionInfo throws an exception,
// we are failing to load a Hive builtin function, which means that
// the given function is not a Hive builtin function.
- case NonFatal(e) => failFunctionLookup(funcName.unquotedString)
+ case NonFatal(e) => failFunctionLookup(funcName)
}
}
val className = functionInfo.getFunctionClass.getName
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
index 9d3b31f39c0f5..e16c9e46b7723 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
@@ -101,7 +101,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session
experimentalMethods.extraStrategies ++
extraPlanningStrategies ++ Seq(
FileSourceStrategy,
- DataSourceStrategy,
+ DataSourceStrategy(conf),
SpecialLimits,
InMemoryScans,
HiveTableScans,
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index 09a5eda6e543f..85a88dff20761 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -24,7 +24,7 @@ import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.hive.common.StatsSetupConst
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics, CatalogStorageFormat, CatalogTable}
+import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan, ScriptTransformation}
@@ -116,7 +116,7 @@ class ResolveHiveSerdeTable(session: SparkSession) extends Rule[LogicalPlan] {
class DetermineTableStats(session: SparkSession) extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
- case relation: CatalogRelation
+ case relation: HiveTableRelation
if DDLUtils.isHiveTable(relation.tableMeta) && relation.tableMeta.stats.isEmpty =>
val table = relation.tableMeta
// TODO: check if this estimate is valid for tables after partition pruning.
@@ -160,9 +160,9 @@ class DetermineTableStats(session: SparkSession) extends Rule[LogicalPlan] {
*/
object HiveAnalysis extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
- case InsertIntoTable(relation: CatalogRelation, partSpec, query, overwrite, ifNotExists)
- if DDLUtils.isHiveTable(relation.tableMeta) =>
- InsertIntoHiveTable(relation.tableMeta, partSpec, query, overwrite, ifNotExists)
+ case InsertIntoTable(r: HiveTableRelation, partSpec, query, overwrite, ifPartitionNotExists)
+ if DDLUtils.isHiveTable(r.tableMeta) =>
+ InsertIntoHiveTable(r.tableMeta, partSpec, query, overwrite, ifPartitionNotExists)
case CreateTable(tableDesc, mode, None) if DDLUtils.isHiveTable(tableDesc) =>
CreateTableCommand(tableDesc, ignoreIfExists = mode == SaveMode.Ignore)
@@ -184,21 +184,21 @@ object HiveAnalysis extends Rule[LogicalPlan] {
case class RelationConversions(
conf: SQLConf,
sessionCatalog: HiveSessionCatalog) extends Rule[LogicalPlan] {
- private def isConvertible(relation: CatalogRelation): Boolean = {
+ private def isConvertible(relation: HiveTableRelation): Boolean = {
val serde = relation.tableMeta.storage.serde.getOrElse("").toLowerCase(Locale.ROOT)
serde.contains("parquet") && conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET) ||
serde.contains("orc") && conf.getConf(HiveUtils.CONVERT_METASTORE_ORC)
}
- private def convert(relation: CatalogRelation): LogicalRelation = {
+ private def convert(relation: HiveTableRelation): LogicalRelation = {
val serde = relation.tableMeta.storage.serde.getOrElse("").toLowerCase(Locale.ROOT)
if (serde.contains("parquet")) {
- val options = Map(ParquetOptions.MERGE_SCHEMA ->
+ val options = relation.tableMeta.storage.properties + (ParquetOptions.MERGE_SCHEMA ->
conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING).toString)
sessionCatalog.metastoreCatalog
.convertToLogicalRelation(relation, options, classOf[ParquetFileFormat], "parquet")
} else {
- val options = Map[String, String]()
+ val options = relation.tableMeta.storage.properties
sessionCatalog.metastoreCatalog
.convertToLogicalRelation(relation, options, classOf[OrcFileFormat], "orc")
}
@@ -207,14 +207,14 @@ case class RelationConversions(
override def apply(plan: LogicalPlan): LogicalPlan = {
plan transformUp {
// Write path
- case InsertIntoTable(r: CatalogRelation, partition, query, overwrite, ifNotExists)
+ case InsertIntoTable(r: HiveTableRelation, partition, query, overwrite, ifPartitionNotExists)
// Inserting into partitioned table is not supported in Parquet/Orc data source (yet).
- if query.resolved && DDLUtils.isHiveTable(r.tableMeta) &&
- !r.isPartitioned && isConvertible(r) =>
- InsertIntoTable(convert(r), partition, query, overwrite, ifNotExists)
+ if query.resolved && DDLUtils.isHiveTable(r.tableMeta) &&
+ !r.isPartitioned && isConvertible(r) =>
+ InsertIntoTable(convert(r), partition, query, overwrite, ifPartitionNotExists)
// Read path
- case relation: CatalogRelation
+ case relation: HiveTableRelation
if DDLUtils.isHiveTable(relation.tableMeta) && isConvertible(relation) =>
convert(relation)
}
@@ -242,7 +242,7 @@ private[hive] trait HiveStrategies {
*/
object HiveTableScans extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case PhysicalOperation(projectList, predicates, relation: CatalogRelation) =>
+ case PhysicalOperation(projectList, predicates, relation: HiveTableRelation) =>
// Filter out all predicates that only deal with partition keys, these are given to the
// hive table scan operator to be used for partition pruning.
val partitionKeyIds = AttributeSet(relation.partitionCols)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
index 16c1103dd1ea3..11795ff1795e0 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
@@ -39,8 +39,10 @@ import org.apache.spark.internal.Logging
import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.CastSupport
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -65,7 +67,7 @@ class HadoopTableReader(
@transient private val tableDesc: TableDesc,
@transient private val sparkSession: SparkSession,
hadoopConf: Configuration)
- extends TableReader with Logging {
+ extends TableReader with CastSupport with Logging {
// Hadoop honors "mapreduce.job.maps" as hint,
// but will ignore when mapreduce.jobtracker.address is "local".
@@ -86,6 +88,8 @@ class HadoopTableReader(
private val _broadcastedHadoopConf =
sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
+ override def conf: SQLConf = sparkSession.sessionState.conf
+
override def makeRDDForTable(hiveTable: HiveTable): RDD[InternalRow] =
makeRDDForTable(
hiveTable,
@@ -227,7 +231,7 @@ class HadoopTableReader(
def fillPartitionKeys(rawPartValues: Array[String], row: InternalRow): Unit = {
partitionKeyAttrs.foreach { case (attr, ordinal) =>
val partOrdinal = partitionKeys.indexOf(attr)
- row(ordinal) = Cast(Literal(rawPartValues(partOrdinal)), attr.dataType).eval(null)
+ row(ordinal) = cast(Literal(rawPartValues(partOrdinal)), attr.dataType).eval(null)
}
}
@@ -377,7 +381,7 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging {
val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map { case (attr, ordinal) =>
soi.getStructFieldRef(attr.name) -> ordinal
- }.unzip
+ }.toArray.unzip
/**
* Builds specific unwrappers ahead of time according to object inspector
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala
index 16a80f9fff452..492a2eaae5710 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.types.StructType
/**
@@ -89,6 +90,16 @@ private[hive] trait HiveClient {
/** Updates the given table with new metadata, optionally renaming the table. */
def alterTable(tableName: String, table: CatalogTable): Unit
+ /**
+ * Updates the given table with a new data schema and table properties, and keeps everything else
+ * unchanged.
+ *
+ * TODO(cloud-fan): it's a little hacky to introduce the schema table properties here in
+ * `HiveClient`, but we don't have a cleaner solution now.
+ */
+ def alterTableDataSchema(
+ dbName: String, tableName: String, newDataSchema: StructType, schemaProps: Map[String, String])
+
/** Creates a new database with the given name. */
def createDatabase(database: CatalogDatabase, ignoreIfExists: Boolean): Unit
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
index 387ec4f967233..ceeb9c1c0da6b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
@@ -22,7 +22,6 @@ import java.util.Locale
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
-import scala.language.reflectiveCalls
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
@@ -47,7 +46,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}
import org.apache.spark.sql.execution.QueryExecutionException
-import org.apache.spark.sql.execution.command.DDLUtils
+import org.apache.spark.sql.hive.HiveExternalCatalog.{DATASOURCE_SCHEMA, DATASOURCE_SCHEMA_NUMPARTS, DATASOURCE_SCHEMA_PART_PREFIX}
import org.apache.spark.sql.hive.client.HiveClientImpl._
import org.apache.spark.sql.types._
import org.apache.spark.util.{CircularBuffer, Utils}
@@ -349,7 +348,7 @@ private[hive] class HiveClientImpl(
Option(client.getDatabase(dbName)).map { d =>
CatalogDatabase(
name = d.getName,
- description = d.getDescription,
+ description = Option(d.getDescription).getOrElse(""),
locationUri = CatalogUtils.stringToURI(d.getLocationUri),
properties = Option(d.getParameters).map(_.asScala.toMap).orNull)
}.getOrElse(throw new NoSuchDatabaseException(dbName))
@@ -462,6 +461,33 @@ private[hive] class HiveClientImpl(
shim.alterTable(client, qualifiedTableName, hiveTable)
}
+ override def alterTableDataSchema(
+ dbName: String,
+ tableName: String,
+ newDataSchema: StructType,
+ schemaProps: Map[String, String]): Unit = withHiveState {
+ val oldTable = client.getTable(dbName, tableName)
+ val hiveCols = newDataSchema.map(toHiveColumn)
+ oldTable.setFields(hiveCols.asJava)
+
+ // remove old schema table properties
+ val it = oldTable.getParameters.entrySet.iterator
+ while (it.hasNext) {
+ val entry = it.next()
+ val isSchemaProp = entry.getKey.startsWith(DATASOURCE_SCHEMA_PART_PREFIX) ||
+ entry.getKey == DATASOURCE_SCHEMA || entry.getKey == DATASOURCE_SCHEMA_NUMPARTS
+ if (isSchemaProp) {
+ it.remove()
+ }
+ }
+
+ // set new schema table properties
+ schemaProps.foreach { case (k, v) => oldTable.setProperty(k, v) }
+
+ val qualifiedTableName = s"$dbName.$tableName"
+ shim.alterTable(client, qualifiedTableName, oldTable)
+ }
+
override def createPartitions(
db: String,
table: String,
@@ -837,20 +863,7 @@ private[hive] object HiveClientImpl {
val (partCols, schema) = table.schema.map(toHiveColumn).partition { c =>
table.partitionColumnNames.contains(c.getName)
}
- // after SPARK-19279, it is not allowed to create a hive table with an empty schema,
- // so here we should not add a default col schema
- if (schema.isEmpty && DDLUtils.isDatasourceTable(table)) {
- // This is a hack to preserve existing behavior. Before Spark 2.0, we do not
- // set a default serde here (this was done in Hive), and so if the user provides
- // an empty schema Hive would automatically populate the schema with a single
- // field "col". However, after SPARK-14388, we set the default serde to
- // LazySimpleSerde so this implicit behavior no longer happens. Therefore,
- // we need to do it in Spark ourselves.
- hiveTable.setFields(
- Seq(new FieldSchema("col", "array", "from deserializer")).asJava)
- } else {
- hiveTable.setFields(schema.asJava)
- }
+ hiveTable.setFields(schema.asJava)
hiveTable.setPartCols(partCols.asJava)
userName.foreach(hiveTable.setOwner)
hiveTable.setCreateTime((table.createTime / 1000).toInt)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
index 7abb9f06b1310..056ffc667a5b7 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
@@ -576,6 +576,18 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
hive.getFunctions(db, pattern).asScala
}
+ /**
+ * An extractor that matches all binary comparison operators except null-safe equality.
+ *
+ * Null-safe equality is not supported by Hive metastore partition predicate pushdown
+ */
+ object SpecialBinaryComparison {
+ def unapply(e: BinaryComparison): Option[(Expression, Expression)] = e match {
+ case _: EqualNullSafe => None
+ case _ => Some((e.left, e.right))
+ }
+ }
+
/**
* Converts catalyst expression to the format that Hive's getPartitionsByFilter() expects, i.e.
* a string that represents partition predicates like "str_key=\"value\" and int_key=1 ...".
@@ -590,14 +602,14 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
.map(col => col.getName).toSet
filters.collect {
- case op @ BinaryComparison(a: Attribute, Literal(v, _: IntegralType)) =>
+ case op @ SpecialBinaryComparison(a: Attribute, Literal(v, _: IntegralType)) =>
s"${a.name} ${op.symbol} $v"
- case op @ BinaryComparison(Literal(v, _: IntegralType), a: Attribute) =>
+ case op @ SpecialBinaryComparison(Literal(v, _: IntegralType), a: Attribute) =>
s"$v ${op.symbol} ${a.name}"
- case op @ BinaryComparison(a: Attribute, Literal(v, _: StringType))
+ case op @ SpecialBinaryComparison(a: Attribute, Literal(v, _: StringType))
if !varcharKeys.contains(a.name) =>
s"""${a.name} ${op.symbol} ${quoteStringLiteral(v.toString)}"""
- case op @ BinaryComparison(Literal(v, _: StringType), a: Attribute)
+ case op @ SpecialBinaryComparison(Literal(v, _: StringType), a: Attribute)
if !varcharKeys.contains(a.name) =>
s"""${quoteStringLiteral(v.toString)} ${op.symbol} ${a.name}"""
}.mkString(" and ")
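
A self-contained sketch of the resulting behaviour (the extractor is redefined locally for illustration): a null-safe comparison simply drops out of the pushed-down filter string and is evaluated on the Spark side instead.

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.expressions._

    object PushdownSketch {
      // Same idea as SpecialBinaryComparison: match any BinaryComparison except <=>,
      // which Hive's getPartitionsByFilter cannot represent.
      object PushableComparison {
        def unapply(e: BinaryComparison): Option[(Expression, Expression)] = e match {
          case _: EqualNullSafe => None
          case _ => Some((e.left, e.right))
        }
      }

      def main(args: Array[String]): Unit = {
        val key = 'part.int
        val filters: Seq[Expression] = Seq(EqualNullSafe(key, Literal(1)), GreaterThan(key, Literal(1)))
        val pushed = filters.collect {
          case op @ PushableComparison(a: Attribute, Literal(v, _)) => s"${a.name} ${op.symbol} $v"
        }.mkString(" and ")
        println(pushed) // "part > 1": the <=> predicate stays on the Spark side
      }
    }
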
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala
index e95f9ea480431..b8aa067cdb903 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala
@@ -22,7 +22,6 @@ import java.lang.reflect.InvocationTargetException
import java.net.{URL, URLClassLoader}
import java.util
-import scala.language.reflectiveCalls
import scala.util.Try
import org.apache.commons.io.{FileUtils, IOUtils}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
index 41c6b18e9d794..65e8b4e3c725c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
@@ -62,7 +62,7 @@ case class CreateHiveTableAsSelectCommand(
Map(),
query,
overwrite = false,
- ifNotExists = false)).toRdd
+ ifPartitionNotExists = false)).toRdd
} else {
// TODO ideally, we should get the output data ready first and then
// add the relation into catalog, just in case of failure occurs while data
@@ -78,7 +78,7 @@ case class CreateHiveTableAsSelectCommand(
Map(),
query,
overwrite = true,
- ifNotExists = false)).toRdd
+ ifPartitionNotExists = false)).toRdd
} catch {
case NonFatal(e) =>
// drop the created table.
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala
index ac735e8b383f6..4a7cd6901923b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala
@@ -116,7 +116,7 @@ class HiveOutputWriter(
private val serializer = {
val serializer = tableDesc.getDeserializerClass.newInstance().asInstanceOf[Serializer]
- serializer.initialize(null, tableDesc.getProperties)
+ serializer.initialize(jobConf, tableDesc.getProperties)
serializer
}
@@ -130,7 +130,7 @@ class HiveOutputWriter(
private val standardOI = ObjectInspectorUtils
.getStandardObjectInspector(
- tableDesc.getDeserializer.getObjectInspector,
+ tableDesc.getDeserializer(jobConf).getObjectInspector,
ObjectInspectorCopyOption.JAVA)
.asInstanceOf[StructObjectInspector]
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
index 666548d1a490b..2ce8ccfb35e0c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
@@ -30,13 +30,15 @@ import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.catalog.CatalogRelation
+import org.apache.spark.sql.catalyst.analysis.CastSupport
+import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.hive._
import org.apache.spark.sql.hive.client.HiveClientImpl
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BooleanType, DataType}
import org.apache.spark.util.Utils
@@ -50,14 +52,16 @@ import org.apache.spark.util.Utils
private[hive]
case class HiveTableScanExec(
requestedAttributes: Seq[Attribute],
- relation: CatalogRelation,
+ relation: HiveTableRelation,
partitionPruningPred: Seq[Expression])(
@transient private val sparkSession: SparkSession)
- extends LeafExecNode {
+ extends LeafExecNode with CastSupport {
require(partitionPruningPred.isEmpty || relation.isPartitioned,
"Partition pruning predicates only supported for partitioned tables.")
+ override def conf: SQLConf = sparkSession.sessionState.conf
+
override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
@@ -104,7 +108,7 @@ case class HiveTableScanExec(
hadoopConf)
private def castFromString(value: String, dataType: DataType) = {
- Cast(Literal(value), dataType).eval(null)
+ cast(Literal(value), dataType).eval(null)
}
private def addColumnMetadataToConf(hiveConf: Configuration): Unit = {
@@ -205,8 +209,8 @@ case class HiveTableScanExec(
val input: AttributeSeq = relation.output
HiveTableScanExec(
requestedAttributes.map(QueryPlan.normalizeExprId(_, input)),
- relation.canonicalized.asInstanceOf[CatalogRelation],
- partitionPruningPred.map(QueryPlan.normalizeExprId(_, input)))(sparkSession)
+ relation.canonicalized.asInstanceOf[HiveTableRelation],
+ QueryPlan.normalizePredicates(partitionPruningPred, input))(sparkSession)
}
override def otherCopyArgs: Seq[AnyRef] = Seq(sparkSession)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index 3682dc850790e..8032d7e728a04 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.hive.execution
-import java.io.IOException
+import java.io.{File, IOException}
import java.net.URI
import java.text.SimpleDateFormat
import java.util.{Date, Locale, Random}
@@ -71,14 +71,15 @@ import org.apache.spark.SparkException
* }}}.
* @param query the logical plan representing data to write to.
* @param overwrite overwrite existing table or partitions.
- * @param ifNotExists If true, only write if the table or partition does not exist.
+ * @param ifPartitionNotExists If true, only write if the partition does not exist.
+ * Only valid for static partitions.
*/
case class InsertIntoHiveTable(
table: CatalogTable,
partition: Map[String, Option[String]],
query: LogicalPlan,
overwrite: Boolean,
- ifNotExists: Boolean) extends RunnableCommand {
+ ifPartitionNotExists: Boolean) extends RunnableCommand {
override protected def innerChildren: Seq[LogicalPlan] = query :: Nil
@@ -97,12 +98,24 @@ case class InsertIntoHiveTable(
val inputPathUri: URI = inputPath.toUri
val inputPathName: String = inputPathUri.getPath
val fs: FileSystem = inputPath.getFileSystem(hadoopConf)
- val stagingPathName: String =
+ var stagingPathName: String =
if (inputPathName.indexOf(stagingDir) == -1) {
new Path(inputPathName, stagingDir).toString
} else {
inputPathName.substring(0, inputPathName.indexOf(stagingDir) + stagingDir.length)
}
+
+ // SPARK-20594: This is a workaround for a Hive bug. The staging directory needs to start
+ // with "." to avoid being deleted when users set hive.exec.stagingdir under the table
+ // directory.
+ if (FileUtils.isSubDir(new Path(stagingPathName), inputPath, fs) &&
+ !stagingPathName.stripPrefix(inputPathName).stripPrefix(File.separator).startsWith(".")) {
+ logDebug(s"The staging dir '$stagingPathName' should be a child directory that starts " +
+ "with '.' to avoid being deleted if we set hive.exec.stagingdir under the table " +
+ "directory.")
+ stagingPathName = new Path(inputPathName, ".hive-staging").toString
+ }
+
val dir: Path =
fs.makeQualified(
new Path(stagingPathName + "_" + executionId + "-" + TaskRunner.getTaskRunnerID))
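
A simplified, standalone sketch of that check (plain prefix test instead of `FileUtils.isSubDir`, illustrative method name): a staging directory nested inside the table directory has to be hidden, otherwise it can be removed together with the old table contents.

    import java.io.File
    import org.apache.hadoop.fs.Path

    // Simplified sketch: if the configured staging dir sits inside the table directory and
    // is not hidden (does not start with "."), fall back to ".hive-staging" under the table.
    def safeStagingPath(inputPathName: String, stagingPathName: String): String = {
      val relative = stagingPathName.stripPrefix(inputPathName).stripPrefix(File.separator)
      if (stagingPathName.startsWith(inputPathName) && !relative.startsWith(".")) {
        new Path(inputPathName, ".hive-staging").toString
      } else {
        stagingPathName
      }
    }
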
@@ -342,7 +355,7 @@ case class InsertIntoHiveTable(
var doHiveOverwrite = overwrite
- if (oldPart.isEmpty || !ifNotExists) {
+ if (oldPart.isEmpty || !ifPartitionNotExists) {
// SPARK-18107: Insert overwrite runs much slower than hive-client.
// Newer Hive largely improves insert overwrite performance. As Spark uses older Hive
// version and we may not want to catch up new Hive version every time. We delete the
@@ -387,7 +400,13 @@ case class InsertIntoHiveTable(
// Attempt to delete the staging directory and the inclusive files. If failed, the files are
// expected to be dropped at the normal termination of VM since deleteOnExit is used.
try {
- createdTempDir.foreach { path => path.getFileSystem(hadoopConf).delete(path, true) }
+ createdTempDir.foreach { path =>
+ val fs = path.getFileSystem(hadoopConf)
+ if (fs.delete(path, true)) {
+ // If we successfully delete the staging directory, remove it from FileSystem's cache.
+ fs.cancelDeleteOnExit(path)
+ }
+ }
} catch {
case NonFatal(e) =>
logWarning(s"Unable to delete staging directory: $stagingDir.\n" + e)
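
The try block above pairs the eager delete with `cancelDeleteOnExit`; a minimal sketch of the pattern (illustrative helper name): the staging path was registered via `deleteOnExit` when it was created, so once it has been deleted there is no reason for the cached FileSystem to keep tracking it until JVM shutdown.

    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.fs.Path

    // Illustrative helper: delete the staging dir now and, on success, drop it from the
    // FileSystem's delete-on-exit bookkeeping.
    def cleanupStagingDir(stagingPath: Path, hadoopConf: Configuration): Unit = {
      val fs = stagingPath.getFileSystem(hadoopConf)
      if (fs.delete(stagingPath, true)) {
        fs.cancelDeleteOnExit(stagingPath)
      }
    }
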
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
index 3a34ec55c8b07..2defd319e0894 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
@@ -58,7 +58,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable
options: Map[String, String],
files: Seq[FileStatus]): Option[StructType] = {
OrcFileOperator.readSchema(
- files.map(_.getPath.toUri.toString),
+ files.map(_.getPath.toString),
Some(sparkSession.sessionState.newHadoopConf())
)
}
@@ -131,29 +131,27 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable
(file: PartitionedFile) => {
val conf = broadcastedHadoopConf.value.value
+ val filePath = new Path(new URI(file.filePath))
+
// SPARK-8501: Empty ORC files always have an empty schema stored in their footer. In this
// case, `OrcFileOperator.readSchema` returns `None`, and we can't read the underlying file
// using the given physical schema. Instead, we simply return an empty iterator.
- val maybePhysicalSchema = OrcFileOperator.readSchema(Seq(file.filePath), Some(conf))
- if (maybePhysicalSchema.isEmpty) {
+ val isEmptyFile = OrcFileOperator.readSchema(Seq(filePath.toString), Some(conf)).isEmpty
+ if (isEmptyFile) {
Iterator.empty
} else {
- val physicalSchema = maybePhysicalSchema.get
- OrcRelation.setRequiredColumns(conf, physicalSchema, requiredSchema)
+ OrcRelation.setRequiredColumns(conf, dataSchema, requiredSchema)
val orcRecordReader = {
val job = Job.getInstance(conf)
FileInputFormat.setInputPaths(job, file.filePath)
- val fileSplit = new FileSplit(
- new Path(new URI(file.filePath)), file.start, file.length, Array.empty
- )
+ val fileSplit = new FileSplit(filePath, file.start, file.length, Array.empty)
// Custom OrcRecordReader is used to get
// ObjectInspector during recordReader creation itself and can
// avoid NameNode call in unwrapOrcStructs per file.
// Specifically would be helpful for partitioned datasets.
- val orcReader = OrcFile.createReader(
- new Path(new URI(file.filePath)), OrcFile.readerOptions(conf))
+ val orcReader = OrcFile.createReader(filePath, OrcFile.readerOptions(conf))
new SparkOrcNewRecordReader(orcReader, conf, fileSplit.getStart, fileSplit.getLength)
}
@@ -163,6 +161,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable
// Unwraps `OrcStruct`s to `UnsafeRow`s
OrcRelation.unwrapOrcStructs(
conf,
+ dataSchema,
requiredSchema,
Some(orcRecordReader.getObjectInspector.asInstanceOf[StructObjectInspector]),
recordsIterator)
@@ -272,25 +271,32 @@ private[orc] object OrcRelation extends HiveInspectors {
def unwrapOrcStructs(
conf: Configuration,
dataSchema: StructType,
+ requiredSchema: StructType,
maybeStructOI: Option[StructObjectInspector],
iterator: Iterator[Writable]): Iterator[InternalRow] = {
val deserializer = new OrcSerde
- val mutableRow = new SpecificInternalRow(dataSchema.map(_.dataType))
- val unsafeProjection = UnsafeProjection.create(dataSchema)
+ val mutableRow = new SpecificInternalRow(requiredSchema.map(_.dataType))
+ val unsafeProjection = UnsafeProjection.create(requiredSchema)
def unwrap(oi: StructObjectInspector): Iterator[InternalRow] = {
- val (fieldRefs, fieldOrdinals) = dataSchema.zipWithIndex.map {
- case (field, ordinal) => oi.getStructFieldRef(field.name) -> ordinal
+ val (fieldRefs, fieldOrdinals) = requiredSchema.zipWithIndex.map {
+ case (field, ordinal) =>
+ var ref = oi.getStructFieldRef(field.name)
+ if (ref == null) {
+ ref = oi.getStructFieldRef("_col" + dataSchema.fieldIndex(field.name))
+ }
+ ref -> ordinal
}.unzip
- val unwrappers = fieldRefs.map(unwrapperFor)
+ val unwrappers = fieldRefs.map(r => if (r == null) null else unwrapperFor(r))
iterator.map { value =>
val raw = deserializer.deserialize(value)
var i = 0
val length = fieldRefs.length
while (i < length) {
- val fieldValue = oi.getStructFieldData(raw, fieldRefs(i))
+ val fieldRef = fieldRefs(i)
+ val fieldValue = if (fieldRef == null) null else oi.getStructFieldData(raw, fieldRef)
if (fieldValue == null) {
mutableRow.setNullAt(fieldOrdinals(i))
} else {
@@ -306,8 +312,8 @@ private[orc] object OrcRelation extends HiveInspectors {
}
def setRequiredColumns(
- conf: Configuration, physicalSchema: StructType, requestedSchema: StructType): Unit = {
- val ids = requestedSchema.map(a => physicalSchema.fieldIndex(a.name): Integer)
+ conf: Configuration, dataSchema: StructType, requestedSchema: StructType): Unit = {
+ val ids = requestedSchema.map(a => dataSchema.fieldIndex(a.name): Integer)
val (sortedIDs, sortedNames) = ids.zip(requestedSchema.fieldNames).sorted.unzip
HiveShim.appendReadColumns(conf, sortedIDs, sortedNames)
}
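
The field lookup in `unwrapOrcStructs` is the subtle part of this change: ORC files written by Hive often carry physical column names such as `_col0` instead of the logical names, so the reader falls back to the column's position in the full data schema. A minimal sketch of that lookup (hypothetical helper name):

    import org.apache.hadoop.hive.serde2.objectinspector.{StructField, StructObjectInspector}
    import org.apache.spark.sql.types.StructType

    // Resolve a requested column against the file's object inspector: first by logical name,
    // then by the "_col<index>" physical name derived from its position in dataSchema.
    // May still return null, which the caller turns into a null column value.
    def resolveFieldRef(
        oi: StructObjectInspector,
        dataSchema: StructType,
        name: String): StructField = {
      Option(oi.getStructFieldRef(name))
        .getOrElse(oi.getStructFieldRef("_col" + dataSchema.fieldIndex(name)))
    }
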
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
index d9bb1f8c7edcc..4612cce80effd 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.hive.test
import java.io.File
+import java.net.URI
import java.util.{Set => JavaSet}
import scala.collection.JavaConverters._
@@ -486,16 +487,16 @@ private[hive] class TestHiveSparkSession(
}
}
+ // Clean out the Hive warehouse between test suites
+ val warehouseDir = new File(new URI(sparkContext.conf.get("spark.sql.warehouse.dir")).getPath)
+ Utils.deleteRecursively(warehouseDir)
+ warehouseDir.mkdir()
+
sharedState.cacheManager.clearCache()
loadedTables.clear()
- sessionState.catalog.clearTempTables()
- sessionState.catalog.tableRelationCache.invalidateAll()
-
+ sessionState.catalog.reset()
metadataHive.reset()
- FunctionRegistry.getFunctionNames.asScala.filterNot(originalUDFs.contains(_)).
- foreach { udfName => FunctionRegistry.unregisterTemporaryUDF(udfName) }
-
// HDFS root scratch dir requires the write all (733) permission. For each connecting user,
// an HDFS scratch dir: ${hive.exec.scratchdir}/ is created, with
// ${hive.scratch.dir.permission}. To resolve the permission issue, the simplest way is to
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
index d3cbf898e2439..48ab4eb9a6178 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
@@ -102,14 +102,18 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto
}
test("uncache of nonexistant tables") {
+ val expectedErrorMsg = "Table or view not found: nonexistantTable"
// make sure table doesn't exist
- intercept[NoSuchTableException](spark.table("nonexistantTable"))
- intercept[NoSuchTableException] {
+ var e = intercept[AnalysisException](spark.table("nonexistantTable")).getMessage
+ assert(e.contains(expectedErrorMsg))
+ e = intercept[AnalysisException] {
spark.catalog.uncacheTable("nonexistantTable")
- }
- intercept[NoSuchTableException] {
+ }.getMessage
+ assert(e.contains(expectedErrorMsg))
+ e = intercept[AnalysisException] {
sql("UNCACHE TABLE nonexistantTable")
- }
+ }.getMessage
+ assert(e.contains(expectedErrorMsg))
sql("UNCACHE TABLE IF EXISTS nonexistantTable")
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala
index 59cc6605a1243..43ce093f8a7dc 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala
@@ -615,6 +615,25 @@ class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingle
assert(output == Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat"))
assert(serde == Some("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe"))
}
+
+ withSQLConf("hive.default.fileformat" -> "orc") {
+ val (desc, exists) = extractTableDesc(
+ "CREATE TABLE IF NOT EXISTS fileformat_test (id int) STORED AS textfile")
+ assert(exists)
+ assert(desc.storage.inputFormat == Some("org.apache.hadoop.mapred.TextInputFormat"))
+ assert(desc.storage.outputFormat ==
+ Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat"))
+ assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"))
+ }
+
+ withSQLConf("hive.default.fileformat" -> "orc") {
+ val (desc, exists) = extractTableDesc(
+ "CREATE TABLE IF NOT EXISTS fileformat_test (id int) STORED AS sequencefile")
+ assert(exists)
+ assert(desc.storage.inputFormat == Some("org.apache.hadoop.mapred.SequenceFileInputFormat"))
+ assert(desc.storage.outputFormat == Some("org.apache.hadoop.mapred.SequenceFileOutputFormat"))
+ assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"))
+ }
}
test("table name with schema") {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogBackwardCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogBackwardCompatibilitySuite.scala
deleted file mode 100644
index 705d43f1f3aba..0000000000000
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogBackwardCompatibilitySuite.scala
+++ /dev/null
@@ -1,264 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.hive
-
-import java.net.URI
-
-import org.apache.hadoop.fs.Path
-import org.scalatest.BeforeAndAfterEach
-
-import org.apache.spark.sql.QueryTest
-import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType}
-import org.apache.spark.sql.hive.client.HiveClient
-import org.apache.spark.sql.hive.test.TestHiveSingleton
-import org.apache.spark.sql.test.SQLTestUtils
-import org.apache.spark.sql.types.StructType
-import org.apache.spark.util.Utils
-
-
-class HiveExternalCatalogBackwardCompatibilitySuite extends QueryTest
- with SQLTestUtils with TestHiveSingleton with BeforeAndAfterEach {
-
- // To test `HiveExternalCatalog`, we need to read/write the raw table meta from/to hive client.
- val hiveClient: HiveClient =
- spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client
-
- val tempDir = Utils.createTempDir().getCanonicalFile
- val tempDirUri = tempDir.toURI
- val tempDirStr = tempDir.getAbsolutePath
-
- override def beforeEach(): Unit = {
- sql("CREATE DATABASE test_db")
- for ((tbl, _) <- rawTablesAndExpectations) {
- hiveClient.createTable(tbl, ignoreIfExists = false)
- }
- }
-
- override def afterEach(): Unit = {
- Utils.deleteRecursively(tempDir)
- hiveClient.dropDatabase("test_db", ignoreIfNotExists = false, cascade = true)
- }
-
- private def getTableMetadata(tableName: String): CatalogTable = {
- spark.sharedState.externalCatalog.getTable("test_db", tableName)
- }
-
- private def defaultTableURI(tableName: String): URI = {
- spark.sessionState.catalog.defaultTablePath(TableIdentifier(tableName, Some("test_db")))
- }
-
- // Raw table metadata that are dumped from tables created by Spark 2.0. Note that, all spark
- // versions prior to 2.1 would generate almost same raw table metadata for a specific table.
- val simpleSchema = new StructType().add("i", "int")
- val partitionedSchema = new StructType().add("i", "int").add("j", "int")
-
- lazy val hiveTable = CatalogTable(
- identifier = TableIdentifier("tbl1", Some("test_db")),
- tableType = CatalogTableType.MANAGED,
- storage = CatalogStorageFormat.empty.copy(
- inputFormat = Some("org.apache.hadoop.mapred.TextInputFormat"),
- outputFormat = Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")),
- schema = simpleSchema)
-
- lazy val externalHiveTable = CatalogTable(
- identifier = TableIdentifier("tbl2", Some("test_db")),
- tableType = CatalogTableType.EXTERNAL,
- storage = CatalogStorageFormat.empty.copy(
- locationUri = Some(tempDirUri),
- inputFormat = Some("org.apache.hadoop.mapred.TextInputFormat"),
- outputFormat = Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")),
- schema = simpleSchema)
-
- lazy val partitionedHiveTable = CatalogTable(
- identifier = TableIdentifier("tbl3", Some("test_db")),
- tableType = CatalogTableType.MANAGED,
- storage = CatalogStorageFormat.empty.copy(
- inputFormat = Some("org.apache.hadoop.mapred.TextInputFormat"),
- outputFormat = Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")),
- schema = partitionedSchema,
- partitionColumnNames = Seq("j"))
-
-
- val simpleSchemaJson =
- """
- |{
- | "type": "struct",
- | "fields": [{
- | "name": "i",
- | "type": "integer",
- | "nullable": true,
- | "metadata": {}
- | }]
- |}
- """.stripMargin
-
- val partitionedSchemaJson =
- """
- |{
- | "type": "struct",
- | "fields": [{
- | "name": "i",
- | "type": "integer",
- | "nullable": true,
- | "metadata": {}
- | },
- | {
- | "name": "j",
- | "type": "integer",
- | "nullable": true,
- | "metadata": {}
- | }]
- |}
- """.stripMargin
-
- lazy val dataSourceTable = CatalogTable(
- identifier = TableIdentifier("tbl4", Some("test_db")),
- tableType = CatalogTableType.MANAGED,
- storage = CatalogStorageFormat.empty.copy(
- properties = Map("path" -> defaultTableURI("tbl4").toString)),
- schema = new StructType(),
- provider = Some("json"),
- properties = Map(
- "spark.sql.sources.provider" -> "json",
- "spark.sql.sources.schema.numParts" -> "1",
- "spark.sql.sources.schema.part.0" -> simpleSchemaJson))
-
- lazy val hiveCompatibleDataSourceTable = CatalogTable(
- identifier = TableIdentifier("tbl5", Some("test_db")),
- tableType = CatalogTableType.MANAGED,
- storage = CatalogStorageFormat.empty.copy(
- properties = Map("path" -> defaultTableURI("tbl5").toString)),
- schema = simpleSchema,
- provider = Some("parquet"),
- properties = Map(
- "spark.sql.sources.provider" -> "parquet",
- "spark.sql.sources.schema.numParts" -> "1",
- "spark.sql.sources.schema.part.0" -> simpleSchemaJson))
-
- lazy val partitionedDataSourceTable = CatalogTable(
- identifier = TableIdentifier("tbl6", Some("test_db")),
- tableType = CatalogTableType.MANAGED,
- storage = CatalogStorageFormat.empty.copy(
- properties = Map("path" -> defaultTableURI("tbl6").toString)),
- schema = new StructType(),
- provider = Some("json"),
- properties = Map(
- "spark.sql.sources.provider" -> "json",
- "spark.sql.sources.schema.numParts" -> "1",
- "spark.sql.sources.schema.part.0" -> partitionedSchemaJson,
- "spark.sql.sources.schema.numPartCols" -> "1",
- "spark.sql.sources.schema.partCol.0" -> "j"))
-
- lazy val externalDataSourceTable = CatalogTable(
- identifier = TableIdentifier("tbl7", Some("test_db")),
- tableType = CatalogTableType.EXTERNAL,
- storage = CatalogStorageFormat.empty.copy(
- locationUri = Some(new URI(defaultTableURI("tbl7") + "-__PLACEHOLDER__")),
- properties = Map("path" -> tempDirStr)),
- schema = new StructType(),
- provider = Some("json"),
- properties = Map(
- "spark.sql.sources.provider" -> "json",
- "spark.sql.sources.schema.numParts" -> "1",
- "spark.sql.sources.schema.part.0" -> simpleSchemaJson))
-
- lazy val hiveCompatibleExternalDataSourceTable = CatalogTable(
- identifier = TableIdentifier("tbl8", Some("test_db")),
- tableType = CatalogTableType.EXTERNAL,
- storage = CatalogStorageFormat.empty.copy(
- locationUri = Some(tempDirUri),
- properties = Map("path" -> tempDirStr)),
- schema = simpleSchema,
- properties = Map(
- "spark.sql.sources.provider" -> "parquet",
- "spark.sql.sources.schema.numParts" -> "1",
- "spark.sql.sources.schema.part.0" -> simpleSchemaJson))
-
- lazy val dataSourceTableWithoutSchema = CatalogTable(
- identifier = TableIdentifier("tbl9", Some("test_db")),
- tableType = CatalogTableType.EXTERNAL,
- storage = CatalogStorageFormat.empty.copy(
- locationUri = Some(new URI(defaultTableURI("tbl9") + "-__PLACEHOLDER__")),
- properties = Map("path" -> tempDirStr)),
- schema = new StructType(),
- provider = Some("json"),
- properties = Map("spark.sql.sources.provider" -> "json"))
-
- // A list of all raw tables we want to test, with their expected schema.
- lazy val rawTablesAndExpectations = Seq(
- hiveTable -> simpleSchema,
- externalHiveTable -> simpleSchema,
- partitionedHiveTable -> partitionedSchema,
- dataSourceTable -> simpleSchema,
- hiveCompatibleDataSourceTable -> simpleSchema,
- partitionedDataSourceTable -> partitionedSchema,
- externalDataSourceTable -> simpleSchema,
- hiveCompatibleExternalDataSourceTable -> simpleSchema,
- dataSourceTableWithoutSchema -> new StructType())
-
- test("make sure we can read table created by old version of Spark") {
- for ((tbl, expectedSchema) <- rawTablesAndExpectations) {
- val readBack = getTableMetadata(tbl.identifier.table)
- assert(readBack.schema.sameType(expectedSchema))
-
- if (tbl.tableType == CatalogTableType.EXTERNAL) {
- // trim the URI prefix
- val tableLocation = readBack.storage.locationUri.get.getPath
- val expectedLocation = tempDir.toURI.getPath.stripSuffix("/")
- assert(tableLocation == expectedLocation)
- }
- }
- }
-
- test("make sure we can alter table location created by old version of Spark") {
- withTempDir { dir =>
- for ((tbl, _) <- rawTablesAndExpectations if tbl.tableType == CatalogTableType.EXTERNAL) {
- val path = dir.toURI.toString.stripSuffix("/")
- sql(s"ALTER TABLE ${tbl.identifier} SET LOCATION '$path'")
-
- val readBack = getTableMetadata(tbl.identifier.table)
-
- // trim the URI prefix
- val actualTableLocation = readBack.storage.locationUri.get.getPath
- val expected = dir.toURI.getPath.stripSuffix("/")
- assert(actualTableLocation == expected)
- }
- }
- }
-
- test("make sure we can rename table created by old version of Spark") {
- for ((tbl, expectedSchema) <- rawTablesAndExpectations) {
- val newName = tbl.identifier.table + "_renamed"
- sql(s"ALTER TABLE ${tbl.identifier} RENAME TO $newName")
-
- val readBack = getTableMetadata(newName)
- assert(readBack.schema.sameType(expectedSchema))
-
- // trim the URI prefix
- val actualTableLocation = readBack.storage.locationUri.get.getPath
- val expectedLocation = if (tbl.tableType == CatalogTableType.EXTERNAL) {
- tempDir.toURI.getPath.stripSuffix("/")
- } else {
- // trim the URI prefix
- defaultTableURI(newName).getPath
- }
- assert(actualTableLocation == expectedLocation)
- }
- }
-}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala
index bd54c043c6ec4..0a522b6a11c80 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala
@@ -63,4 +63,54 @@ class HiveExternalCatalogSuite extends ExternalCatalogSuite {
assert(!rawTable.properties.contains(HiveExternalCatalog.DATASOURCE_PROVIDER))
assert(DDLUtils.isHiveTable(externalCatalog.getTable("db1", "hive_tbl")))
}
+
+ Seq("parquet", "hive").foreach { format =>
+ test(s"Partition columns should be put at the end of table schema for the format $format") {
+ val catalog = newBasicCatalog()
+ val newSchema = new StructType()
+ .add("col1", "int")
+ .add("col2", "string")
+ .add("partCol1", "int")
+ .add("partCol2", "string")
+ val table = CatalogTable(
+ identifier = TableIdentifier("tbl", Some("db1")),
+ tableType = CatalogTableType.MANAGED,
+ storage = CatalogStorageFormat.empty,
+ schema = new StructType()
+ .add("col1", "int")
+ .add("partCol1", "int")
+ .add("partCol2", "string")
+ .add("col2", "string"),
+ provider = Some(format),
+ partitionColumnNames = Seq("partCol1", "partCol2"))
+ catalog.createTable(table, ignoreIfExists = false)
+
+ val restoredTable = externalCatalog.getTable("db1", "tbl")
+ assert(restoredTable.schema == newSchema)
+ }
+ }
+
+ test("SPARK-22306: alter table schema should not erase the bucketing metadata at hive side") {
+ val catalog = newBasicCatalog()
+ externalCatalog.client.runSqlHive(
+ """
+ |CREATE TABLE db1.t(a string, b string)
+ |CLUSTERED BY (a, b) SORTED BY (a, b) INTO 10 BUCKETS
+ |STORED AS PARQUET
+ """.stripMargin)
+
+ val newSchema = new StructType().add("a", "string").add("b", "string").add("c", "string")
+ catalog.alterTableDataSchema("db1", "t", newSchema)
+
+ assert(catalog.getTable("db1", "t").schema == newSchema)
+ val bucketString = externalCatalog.client.runSqlHive("DESC FORMATTED db1.t")
+ .filter(_.contains("Num Buckets")).head
+ assert(bucketString.contains("10"))
+ }
+
+ test("SPARK-23001: NullPointerException when running desc database") {
+ val catalog = newBasicCatalog()
+ catalog.createDatabase(newDb("dbWithNullDesc").copy(description = null), ignoreIfExists = false)
+ assert(catalog.getDatabase("dbWithNullDesc").description == "")
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala
new file mode 100644
index 0000000000000..e6a6cac358efa
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala
@@ -0,0 +1,242 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import java.io.File
+import java.nio.file.Files
+
+import scala.sys.process._
+
+import org.apache.spark.TestUtils
+import org.apache.spark.sql.{QueryTest, Row, SparkSession}
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.catalog.CatalogTableType
+import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.util.Utils
+
+/**
+ * Test HiveExternalCatalog backward compatibility.
+ *
+ * Note that this test suite will automatically download Spark binary packages of different
+ * versions to the local directory `/tmp/test-spark`. If a Spark folder for the expected
+ * version already exists under this directory, e.g. `/tmp/test-spark/spark-2.0.3`, the
+ * download for that version is skipped.
+ */
+class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils {
+ private val wareHousePath = Utils.createTempDir(namePrefix = "warehouse")
+ private val tmpDataDir = Utils.createTempDir(namePrefix = "test-data")
+ // For local testing, you can set `sparkTestingDir` to a static value like `/tmp/test-spark`, to
+ // avoid downloading Spark of different versions in each run.
+ private val sparkTestingDir = new File("/tmp/test-spark")
+ private val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
+
+ override def afterAll(): Unit = {
+ Utils.deleteRecursively(wareHousePath)
+ Utils.deleteRecursively(tmpDataDir)
+ Utils.deleteRecursively(sparkTestingDir)
+ super.afterAll()
+ }
+
+ private def tryDownloadSpark(version: String, path: String): Unit = {
+ // Try a few mirrors first; fall back to Apache archive
+ val mirrors =
+ (0 until 2).flatMap { _ =>
+ try {
+ Some(Seq("wget",
+ "https://www.apache.org/dyn/closer.lua?preferred=true", "-q", "-O", "-").!!.trim)
+ } catch {
+ // If we can't get a mirror URL, skip it. No retry.
+ case _: Exception => None
+ }
+ }
+ val sites = mirrors.distinct :+ "https://archive.apache.org/dist"
+ logInfo(s"Trying to download Spark $version from $sites")
+ for (site <- sites) {
+ val filename = s"spark-$version-bin-hadoop2.7.tgz"
+ val url = s"$site/spark/spark-$version/$filename"
+ logInfo(s"Downloading Spark $version from $url")
+ if (Seq("wget", url, "-q", "-P", path).! == 0) {
+ val downloaded = new File(sparkTestingDir, filename).getCanonicalPath
+ val targetDir = new File(sparkTestingDir, s"spark-$version").getCanonicalPath
+
+ Seq("mkdir", targetDir).!
+ val exitCode = Seq("tar", "-xzf", downloaded, "-C", targetDir, "--strip-components=1").!
+ Seq("rm", downloaded).!
+
+ // For a corrupted file, `tar` returns non-zero values. However, we also need to check
+ // the extracted file because `tar` returns 0 for an empty file.
+ val sparkSubmit = new File(sparkTestingDir, s"spark-$version/bin/spark-submit")
+ if (exitCode == 0 && sparkSubmit.exists()) {
+ return
+ } else {
+ Seq("rm", "-rf", targetDir).!
+ }
+ }
+ logWarning(s"Failed to download Spark $version from $url")
+ }
+ fail(s"Unable to download Spark $version")
+ }
+
+ private def genDataDir(name: String): String = {
+ new File(tmpDataDir, name).getCanonicalPath
+ }
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+
+ val tempPyFile = File.createTempFile("test", ".py")
+ // scalastyle:off line.size.limit
+ Files.write(tempPyFile.toPath,
+ s"""
+ |from pyspark.sql import SparkSession
+ |import os
+ |
+ |spark = SparkSession.builder.enableHiveSupport().getOrCreate()
+ |version_index = spark.conf.get("spark.sql.test.version.index", None)
+ |
+ |spark.sql("create table data_source_tbl_{} using json as select 1 i".format(version_index))
+ |
+ |spark.sql("create table hive_compatible_data_source_tbl_{} using parquet as select 1 i".format(version_index))
+ |
+ |json_file = "${genDataDir("json_")}" + str(version_index)
+ |spark.range(1, 2).selectExpr("cast(id as int) as i").write.json(json_file)
+ |spark.sql("create table external_data_source_tbl_{}(i int) using json options (path '{}')".format(version_index, json_file))
+ |
+ |parquet_file = "${genDataDir("parquet_")}" + str(version_index)
+ |spark.range(1, 2).selectExpr("cast(id as int) as i").write.parquet(parquet_file)
+ |spark.sql("create table hive_compatible_external_data_source_tbl_{}(i int) using parquet options (path '{}')".format(version_index, parquet_file))
+ |
+ |json_file2 = "${genDataDir("json2_")}" + str(version_index)
+ |spark.range(1, 2).selectExpr("cast(id as int) as i").write.json(json_file2)
+ |spark.sql("create table external_table_without_schema_{} using json options (path '{}')".format(version_index, json_file2))
+ |
+ |parquet_file2 = "${genDataDir("parquet2_")}" + str(version_index)
+ |spark.range(1, 3).selectExpr("1 as i", "cast(id as int) as p", "1 as j").write.parquet(os.path.join(parquet_file2, "p=1"))
+ |spark.sql("create table tbl_with_col_overlap_{} using parquet options(path '{}')".format(version_index, parquet_file2))
+ |
+ |spark.sql("create view v_{} as select 1 i".format(version_index))
+ """.stripMargin.getBytes("utf8"))
+ // scalastyle:on line.size.limit
+
+ PROCESS_TABLES.testingVersions.zipWithIndex.foreach { case (version, index) =>
+ val sparkHome = new File(sparkTestingDir, s"spark-$version")
+ if (!sparkHome.exists()) {
+ tryDownloadSpark(version, sparkTestingDir.getCanonicalPath)
+ }
+
+ val args = Seq(
+ "--name", "prepare testing tables",
+ "--master", "local[2]",
+ "--conf", "spark.ui.enabled=false",
+ "--conf", "spark.master.rest.enabled=false",
+ "--conf", s"spark.sql.warehouse.dir=${wareHousePath.getCanonicalPath}",
+ "--conf", s"spark.sql.test.version.index=$index",
+ "--driver-java-options", s"-Dderby.system.home=${wareHousePath.getCanonicalPath}",
+ tempPyFile.getCanonicalPath)
+ runSparkSubmit(args, Some(sparkHome.getCanonicalPath))
+ }
+
+ tempPyFile.delete()
+ }
+
+ test("backward compatibility") {
+ val args = Seq(
+ "--class", PROCESS_TABLES.getClass.getName.stripSuffix("$"),
+ "--name", "HiveExternalCatalog backward compatibility test",
+ "--master", "local[2]",
+ "--conf", "spark.ui.enabled=false",
+ "--conf", "spark.master.rest.enabled=false",
+ "--conf", s"spark.sql.warehouse.dir=${wareHousePath.getCanonicalPath}",
+ "--driver-java-options", s"-Dderby.system.home=${wareHousePath.getCanonicalPath}",
+ unusedJar.toString)
+ runSparkSubmit(args)
+ }
+}
+
+object PROCESS_TABLES extends QueryTest with SQLTestUtils {
+ // Tests the latest version of every release line.
+ val testingVersions = Seq("2.0.2", "2.1.3", "2.2.2")
+
+ protected var spark: SparkSession = _
+
+ def main(args: Array[String]): Unit = {
+ val session = SparkSession.builder()
+ .enableHiveSupport()
+ .getOrCreate()
+ spark = session
+ import session.implicits._
+
+ testingVersions.indices.foreach { index =>
+ Seq(
+ s"data_source_tbl_$index",
+ s"hive_compatible_data_source_tbl_$index",
+ s"external_data_source_tbl_$index",
+ s"hive_compatible_external_data_source_tbl_$index",
+ s"external_table_without_schema_$index").foreach { tbl =>
+ val tableMeta = spark.sharedState.externalCatalog.getTable("default", tbl)
+
+ // make sure we can insert and query these tables.
+ session.sql(s"insert into $tbl select 2")
+ checkAnswer(session.sql(s"select * from $tbl"), Row(1) :: Row(2) :: Nil)
+ checkAnswer(session.sql(s"select i from $tbl where i > 1"), Row(2))
+
+ // make sure we can rename table.
+ val newName = tbl + "_renamed"
+ sql(s"ALTER TABLE $tbl RENAME TO $newName")
+ val readBack = spark.sharedState.externalCatalog.getTable("default", newName)
+
+ val actualTableLocation = readBack.storage.locationUri.get.getPath
+ val expectedLocation = if (tableMeta.tableType == CatalogTableType.EXTERNAL) {
+ tableMeta.storage.locationUri.get.getPath
+ } else {
+ spark.sessionState.catalog.defaultTablePath(TableIdentifier(newName, None)).getPath
+ }
+ assert(actualTableLocation == expectedLocation)
+
+ // make sure we can alter table location.
+ withTempDir { dir =>
+ val path = dir.toURI.toString.stripSuffix("/")
+ sql(s"ALTER TABLE ${tbl}_renamed SET LOCATION '$path'")
+ val readBack = spark.sharedState.externalCatalog.getTable("default", tbl + "_renamed")
+ val actualTableLocation = readBack.storage.locationUri.get.getPath
+ val expected = dir.toURI.getPath.stripSuffix("/")
+ assert(actualTableLocation == expected)
+ }
+ }
+
+ // test permanent view
+ checkAnswer(sql(s"select i from v_$index"), Row(1))
+
+ // SPARK-22356: overlapped columns between data and partition schema in data source tables
+ val tbl_with_col_overlap = s"tbl_with_col_overlap_$index"
+ // For Spark 2.2.0 and 2.1.x, the behavior is different from Spark 2.0.
+ if (testingVersions(index).startsWith("2.1") || testingVersions(index) == "2.2.0") {
+ spark.sql("msck repair table " + tbl_with_col_overlap)
+ assert(spark.table(tbl_with_col_overlap).columns === Array("i", "j", "p"))
+ checkAnswer(spark.table(tbl_with_col_overlap), Row(1, 1, 1) :: Row(1, 1, 1) :: Nil)
+ assert(sql("desc " + tbl_with_col_overlap).select("col_name")
+ .as[String].collect().mkString(",").contains("i,j,p"))
+ } else {
+ assert(spark.table(tbl_with_col_overlap).columns === Array("i", "p", "j"))
+ checkAnswer(spark.table(tbl_with_col_overlap), Row(1, 1, 1) :: Row(1, 1, 1) :: Nil)
+ assert(sql("desc " + tbl_with_col_overlap).select("col_name")
+ .as[String].collect().mkString(",").contains("i,p,j"))
+ }
+ }
+ }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala
index 0c28a1b609bb8..e71aba72c31fe 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala
@@ -31,14 +31,22 @@ import org.apache.spark.sql.test.SQLTestUtils
class HiveMetadataCacheSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
test("SPARK-16337 temporary view refresh") {
- withTempView("view_refresh") {
+ checkRefreshView(isTemp = true)
+ }
+
+ test("view refresh") {
+ checkRefreshView(isTemp = false)
+ }
+
+ private def checkRefreshView(isTemp: Boolean) {
+ withView("view_refresh") {
withTable("view_table") {
// Create a Parquet directory
spark.range(start = 0, end = 100, step = 1, numPartitions = 3)
.write.saveAsTable("view_table")
- // Read the table in
- spark.table("view_table").filter("id > -1").createOrReplaceTempView("view_refresh")
+ val temp = if (isTemp) "TEMPORARY" else ""
+ spark.sql(s"CREATE $temp VIEW view_refresh AS SELECT * FROM view_table WHERE id > -1")
assert(sql("select count(*) from view_refresh").first().getLong(0) == 100)
// Delete a file using the Hadoop file system interface since the path returned by
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala
index 319d02613f00a..d271acc63de08 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala
@@ -46,7 +46,7 @@ class HiveSchemaInferenceSuite
override def afterEach(): Unit = {
super.afterEach()
- spark.sessionState.catalog.tableRelationCache.invalidateAll()
+ spark.sessionState.catalog.invalidateAllCachedTables()
FileStatusCache.resetForTesting()
}
@@ -104,7 +104,7 @@ class HiveSchemaInferenceSuite
identifier = TableIdentifier(table = TEST_TABLE_NAME, database = Option(DATABASE)),
tableType = CatalogTableType.EXTERNAL,
storage = CatalogStorageFormat(
- locationUri = Option(new java.net.URI(dir.getAbsolutePath)),
+ locationUri = Option(dir.toURI),
inputFormat = serde.inputFormat,
outputFormat = serde.outputFormat,
serde = serde.serde,
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
index 5f15a705a2e99..cf145c845eef0 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
@@ -18,17 +18,11 @@
package org.apache.spark.sql.hive
import java.io.{BufferedWriter, File, FileWriter}
-import java.sql.Timestamp
-import java.util.Date
-import scala.collection.mutable.ArrayBuffer
import scala.tools.nsc.Properties
import org.apache.hadoop.fs.Path
import org.scalatest.{BeforeAndAfterEach, Matchers}
-import org.scalatest.concurrent.Timeouts
-import org.scalatest.exceptions.TestFailedDueToTimeoutException
-import org.scalatest.time.SpanSugar._
import org.apache.spark._
import org.apache.spark.internal.Logging
@@ -38,7 +32,6 @@ import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext}
-import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer
import org.apache.spark.sql.types.{DecimalType, StructType}
import org.apache.spark.util.{ResetSystemProperties, Utils}
@@ -46,11 +39,10 @@ import org.apache.spark.util.{ResetSystemProperties, Utils}
* This suite tests spark-submit with applications using HiveContext.
*/
class HiveSparkSubmitSuite
- extends SparkFunSuite
+ extends SparkSubmitTestUtils
with Matchers
with BeforeAndAfterEach
- with ResetSystemProperties
- with Timeouts {
+ with ResetSystemProperties {
// TODO: rewrite these or mark them as slow tests to be run sparingly
@@ -333,71 +325,6 @@ class HiveSparkSubmitSuite
unusedJar.toString)
runSparkSubmit(argsForShowTables)
}
-
- // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly.
- // This is copied from org.apache.spark.deploy.SparkSubmitSuite
- private def runSparkSubmit(args: Seq[String]): Unit = {
- val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!"))
- val history = ArrayBuffer.empty[String]
- val sparkSubmit = if (Utils.isWindows) {
- // On Windows, `ProcessBuilder.directory` does not change the current working directory.
- new File("..\\..\\bin\\spark-submit.cmd").getAbsolutePath
- } else {
- "./bin/spark-submit"
- }
- val commands = Seq(sparkSubmit) ++ args
- val commandLine = commands.mkString("'", "' '", "'")
-
- val builder = new ProcessBuilder(commands: _*).directory(new File(sparkHome))
- val env = builder.environment()
- env.put("SPARK_TESTING", "1")
- env.put("SPARK_HOME", sparkHome)
-
- def captureOutput(source: String)(line: String): Unit = {
- // This test suite has some weird behaviors when executed on Jenkins:
- //
- // 1. Sometimes it gets extremely slow out of unknown reason on Jenkins. Here we add a
- // timestamp to provide more diagnosis information.
- // 2. Log lines are not correctly redirected to unit-tests.log as expected, so here we print
- // them out for debugging purposes.
- val logLine = s"${new Timestamp(new Date().getTime)} - $source> $line"
- // scalastyle:off println
- println(logLine)
- // scalastyle:on println
- history += logLine
- }
-
- val process = builder.start()
- new ProcessOutputCapturer(process.getInputStream, captureOutput("stdout")).start()
- new ProcessOutputCapturer(process.getErrorStream, captureOutput("stderr")).start()
-
- try {
- val exitCode = failAfter(300.seconds) { process.waitFor() }
- if (exitCode != 0) {
- // include logs in output. Note that logging is async and may not have completed
- // at the time this exception is raised
- Thread.sleep(1000)
- val historyLog = history.mkString("\n")
- fail {
- s"""spark-submit returned with exit code $exitCode.
- |Command line: $commandLine
- |
- |$historyLog
- """.stripMargin
- }
- }
- } catch {
- case to: TestFailedDueToTimeoutException =>
- val historyLog = history.mkString("\n")
- fail(s"Timeout of $commandLine" +
- s" See the log4j logs for more detail." +
- s"\n$historyLog", to)
- case t: Throwable => throw t
- } finally {
- // Ensure we still kill the process in case it timed out
- process.destroy()
- }
- }
}
object SetMetastoreURLTest extends Logging {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
index d6999af84eac0..d69615669348b 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
@@ -166,72 +166,54 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef
sql("DROP TABLE tmp_table")
}
- test("INSERT OVERWRITE - partition IF NOT EXISTS") {
- withTempDir { tmpDir =>
- val table = "table_with_partition"
- withTable(table) {
- val selQuery = s"select c1, p1, p2 from $table"
- sql(
- s"""
- |CREATE TABLE $table(c1 string)
- |PARTITIONED by (p1 string,p2 string)
- |location '${tmpDir.toURI.toString}'
- """.stripMargin)
- sql(
- s"""
- |INSERT OVERWRITE TABLE $table
- |partition (p1='a',p2='b')
- |SELECT 'blarr'
- """.stripMargin)
- checkAnswer(
- sql(selQuery),
- Row("blarr", "a", "b"))
-
- sql(
- s"""
- |INSERT OVERWRITE TABLE $table
- |partition (p1='a',p2='b')
- |SELECT 'blarr2'
- """.stripMargin)
- checkAnswer(
- sql(selQuery),
- Row("blarr2", "a", "b"))
+ testPartitionedTable("INSERT OVERWRITE - partition IF NOT EXISTS") { tableName =>
+ val selQuery = s"select a, b, c, d from $tableName"
+ sql(
+ s"""
+ |INSERT OVERWRITE TABLE $tableName
+ |partition (b=2, c=3)
+ |SELECT 1, 4
+ """.stripMargin)
+ checkAnswer(sql(selQuery), Row(1, 2, 3, 4))
- var e = intercept[AnalysisException] {
- sql(
- s"""
- |INSERT OVERWRITE TABLE $table
- |partition (p1='a',p2) IF NOT EXISTS
- |SELECT 'blarr3', 'newPartition'
- """.stripMargin)
- }
- assert(e.getMessage.contains(
- "Dynamic partitions do not support IF NOT EXISTS. Specified partitions with value: [p2]"))
+ sql(
+ s"""
+ |INSERT OVERWRITE TABLE $tableName
+ |partition (b=2, c=3)
+ |SELECT 5, 6
+ """.stripMargin)
+ checkAnswer(sql(selQuery), Row(5, 2, 3, 6))
+
+ val e = intercept[AnalysisException] {
+ sql(
+ s"""
+ |INSERT OVERWRITE TABLE $tableName
+ |partition (b=2, c) IF NOT EXISTS
+ |SELECT 7, 8, 3
+ """.stripMargin)
+ }
+ assert(e.getMessage.contains(
+ "Dynamic partitions do not support IF NOT EXISTS. Specified partitions with value: [c]"))
- e = intercept[AnalysisException] {
- sql(
- s"""
- |INSERT OVERWRITE TABLE $table
- |partition (p1='a',p2) IF NOT EXISTS
- |SELECT 'blarr3', 'b'
- """.stripMargin)
- }
- assert(e.getMessage.contains(
- "Dynamic partitions do not support IF NOT EXISTS. Specified partitions with value: [p2]"))
+ // If the partition already exists, the insert will overwrite the data
+ // unless users specify IF NOT EXISTS
+ sql(
+ s"""
+ |INSERT OVERWRITE TABLE $tableName
+ |partition (b=2, c=3) IF NOT EXISTS
+ |SELECT 9, 10
+ """.stripMargin)
+ checkAnswer(sql(selQuery), Row(5, 2, 3, 6))
- // If the partition already exists, the insert will overwrite the data
- // unless users specify IF NOT EXISTS
- sql(
- s"""
- |INSERT OVERWRITE TABLE $table
- |partition (p1='a',p2='b') IF NOT EXISTS
- |SELECT 'blarr3'
- """.stripMargin)
- checkAnswer(
- sql(selQuery),
- Row("blarr2", "a", "b"))
- }
- }
+ // ADD PARTITION has the same effect, even if no actual data is inserted.
+ sql(s"ALTER TABLE $tableName ADD PARTITION (b=21, c=31)")
+ sql(
+ s"""
+ |INSERT OVERWRITE TABLE $tableName
+ |partition (b=21, c=31) IF NOT EXISTS
+ |SELECT 20, 24
+ """.stripMargin)
+ checkAnswer(sql(selQuery), Row(5, 2, 3, 6))
}
test("Insert ArrayType.containsNull == false") {
@@ -486,6 +468,28 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef
}
}
+ test("SPARK-21165: FileFormatWriter should only rely on attributes from analyzed plan") {
+ withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) {
+ withTable("tab1", "tab2") {
+ Seq(("a", "b", 3)).toDF("word", "first", "length").write.saveAsTable("tab1")
+
+ spark.sql(
+ """
+ |CREATE TABLE tab2 (word string, length int)
+ |PARTITIONED BY (first string)
+ """.stripMargin)
+
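+      // Dynamic-partition insert where the partition column comes from a cast expression;
+      // per SPARK-21165, the writer must resolve attributes against the analyzed plan.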
+ spark.sql(
+ """
+ |INSERT INTO TABLE tab2 PARTITION(first)
+ |SELECT word, length, cast(first as string) as first FROM tab1
+ """.stripMargin)
+
+ checkAnswer(spark.table("tab2"), Row("a", 3, "b"))
+ }
+ }
+ }
+
testPartitionedTable("insertInto() should reject extra columns") {
tableName =>
sql("CREATE TABLE t (a INT, b INT, c INT, d INT, e INT)")
@@ -494,4 +498,15 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef
spark.table("t").write.insertInto(tableName)
}
}
+
+ test("SPARK-20594: hive.exec.stagingdir was deleted by Hive") {
+    // Set hive.exec.stagingdir to a directory under the table directory whose name does not start with ".".
+ withSQLConf("hive.exec.stagingdir" -> "./test") {
+ withTable("test_table") {
+ sql("CREATE TABLE test_table (key int)")
+ sql("INSERT OVERWRITE TABLE test_table SELECT 1")
+ checkAnswer(sql("SELECT * FROM test_table"), Row(1))
+ }
+ }
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
index b554694815571..d62ed1923c91a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
@@ -746,7 +746,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
val hiveTable = CatalogTable(
identifier = TableIdentifier(tableName, Some("default")),
tableType = CatalogTableType.MANAGED,
- schema = new StructType,
+ schema = HiveExternalCatalog.EMPTY_DATA_SCHEMA,
provider = Some("json"),
storage = CatalogStorageFormat(
locationUri = None,
@@ -998,7 +998,6 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
spark.sql("""drop database if exists testdb8156 CASCADE""")
}
-
test("skip hive metadata on table creation") {
withTempDir { tempPath =>
val schema = StructType((1 to 5).map(i => StructField(s"c_$i", StringType)))
@@ -1272,7 +1271,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
val hiveTable = CatalogTable(
identifier = TableIdentifier("t", Some("default")),
tableType = CatalogTableType.MANAGED,
- schema = new StructType,
+ schema = HiveExternalCatalog.EMPTY_DATA_SCHEMA,
provider = Some("json"),
storage = CatalogStorageFormat.empty,
properties = Map(
@@ -1350,6 +1349,18 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
}
}
+ Seq("orc", "parquet", "csv", "json", "text").foreach { format =>
+ test(s"SPARK-22146: read files containing special characters using $format") {
+ val nameWithSpecialChars = s"sp&cial%chars"
+ withTempDir { dir =>
+ val tmpFile = s"$dir/$nameWithSpecialChars"
+ spark.createDataset(Seq("a", "b")).write.format(format).save(tmpFile)
+ val fileContent = spark.read.format(format).load(tmpFile)
+ checkAnswer(fileContent, Seq(Row("a"), Row("b")))
+ }
+ }
+ }
+
private def withDebugMode(f: => Unit): Unit = {
val previousValue = sparkSession.sparkContext.conf.get(DEBUG_MODE)
try {
@@ -1359,30 +1370,4 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
sparkSession.sparkContext.conf.set(DEBUG_MODE, previousValue)
}
}
-
- test("SPARK-18464: support old table which doesn't store schema in table properties") {
- withTable("old") {
- withTempPath { path =>
- Seq(1 -> "a").toDF("i", "j").write.parquet(path.getAbsolutePath)
- val tableDesc = CatalogTable(
- identifier = TableIdentifier("old", Some("default")),
- tableType = CatalogTableType.EXTERNAL,
- storage = CatalogStorageFormat.empty.copy(
- properties = Map("path" -> path.getAbsolutePath)
- ),
- schema = new StructType(),
- provider = Some("parquet"),
- properties = Map(
- HiveExternalCatalog.DATASOURCE_PROVIDER -> "parquet"))
- hiveClient.createTable(tableDesc, ignoreIfExists = false)
-
- checkAnswer(spark.table("old"), Row(1, "a"))
-
- val expectedSchema = StructType(Seq(
- StructField("i", IntegerType, nullable = true),
- StructField("j", StringType, nullable = true)))
- assert(table("old").schema === expectedSchema)
- }
- }
- }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala
index 43b6bf5feeb60..b2dc401ce1efc 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.hive
import java.io.File
+import java.sql.Timestamp
import com.google.common.io.Files
import org.apache.hadoop.fs.FileSystem
@@ -68,4 +69,20 @@ class QueryPartitionSuite extends QueryTest with SQLTestUtils with TestHiveSingl
sql("DROP TABLE IF EXISTS createAndInsertTest")
}
}
+
+ test("SPARK-21739: Cast expression should initialize timezoneId") {
+ withTable("table_with_timestamp_partition") {
+ sql("CREATE TABLE table_with_timestamp_partition(value int) PARTITIONED BY (ts TIMESTAMP)")
+ sql("INSERT OVERWRITE TABLE table_with_timestamp_partition " +
+ "PARTITION (ts = '2010-01-01 00:00:00.000') VALUES (1)")
+
+ // test for Cast expression in TableReader
+ checkAnswer(sql("SELECT * FROM table_with_timestamp_partition"),
+ Seq(Row(1, Timestamp.valueOf("2010-01-01 00:00:00.000"))))
+
+ // test for Cast expression in HiveTableScanExec
+ checkAnswer(sql("SELECT value FROM table_with_timestamp_partition " +
+ "WHERE ts = '2010-01-01 00:00:00.000'"), Row(1))
+ }
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SparkSubmitTestUtils.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SparkSubmitTestUtils.scala
new file mode 100644
index 0000000000000..4b28d4f362b80
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SparkSubmitTestUtils.scala
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import java.io.File
+import java.sql.Timestamp
+import java.util.Date
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.scalatest.concurrent.Timeouts
+import org.scalatest.exceptions.TestFailedDueToTimeoutException
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer
+import org.apache.spark.util.Utils
+
+trait SparkSubmitTestUtils extends SparkFunSuite with Timeouts {
+
+ // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly.
+ // This is copied from org.apache.spark.deploy.SparkSubmitSuite
+ protected def runSparkSubmit(args: Seq[String], sparkHomeOpt: Option[String] = None): Unit = {
+ val sparkHome = sparkHomeOpt.getOrElse(
+ sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")))
+ val history = ArrayBuffer.empty[String]
+ val sparkSubmit = if (Utils.isWindows) {
+ // On Windows, `ProcessBuilder.directory` does not change the current working directory.
+ new File("..\\..\\bin\\spark-submit.cmd").getAbsolutePath
+ } else {
+ "./bin/spark-submit"
+ }
+ val commands = Seq(sparkSubmit) ++ args
+ val commandLine = commands.mkString("'", "' '", "'")
+
+ val builder = new ProcessBuilder(commands: _*).directory(new File(sparkHome))
+ val env = builder.environment()
+ env.put("SPARK_TESTING", "1")
+ env.put("SPARK_HOME", sparkHome)
+
+ def captureOutput(source: String)(line: String): Unit = {
+ // This test suite has some weird behaviors when executed on Jenkins:
+ //
+      // 1. Sometimes it gets extremely slow on Jenkins for unknown reasons. Here we add a
+      //    timestamp to provide more diagnostic information.
+ // 2. Log lines are not correctly redirected to unit-tests.log as expected, so here we print
+ // them out for debugging purposes.
+ val logLine = s"${new Timestamp(new Date().getTime)} - $source> $line"
+ // scalastyle:off println
+ println(logLine)
+ // scalastyle:on println
+ history += logLine
+ }
+
+ val process = builder.start()
+ new ProcessOutputCapturer(process.getInputStream, captureOutput("stdout")).start()
+ new ProcessOutputCapturer(process.getErrorStream, captureOutput("stderr")).start()
+
+ try {
+ val exitCode = failAfter(300.seconds) { process.waitFor() }
+ if (exitCode != 0) {
+        // Include logs in the output. Note that logging is async and may not have completed
+        // at the time this exception is raised.
+ Thread.sleep(1000)
+ val historyLog = history.mkString("\n")
+ fail {
+ s"""spark-submit returned with exit code $exitCode.
+ |Command line: $commandLine
+ |
+ |$historyLog
+ """.stripMargin
+ }
+ }
+ } catch {
+ case to: TestFailedDueToTimeoutException =>
+ val historyLog = history.mkString("\n")
+        fail(s"Timeout of $commandLine." +
+          s" See the log4j logs for more detail." +
+ s"\n$historyLog", to)
+ case t: Throwable => throw t
+ } finally {
+ // Ensure we still kill the process in case it timed out
+ process.destroy()
+ }
+ }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
index 3191b9975fbf9..a9caad897c589 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
@@ -23,7 +23,7 @@ import scala.reflect.ClassTag
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics}
+import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, HiveTableRelation}
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.joins._
@@ -31,6 +31,7 @@ import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
+
class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleton {
test("Hive serde tables should fallback to HDFS for size estimation") {
@@ -59,7 +60,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
|LOCATION '${tempDir.toURI}'""".stripMargin)
val relation = spark.table("csv_table").queryExecution.analyzed.children.head
- .asInstanceOf[CatalogRelation]
+ .asInstanceOf[HiveTableRelation]
val properties = relation.tableMeta.properties
assert(properties("totalSize").toLong <= 0, "external table totalSize must be <= 0")
@@ -125,6 +126,77 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
TableIdentifier("tempTable"), ignoreIfNotExists = true, purge = false)
}
+ test("SPARK-21079 - analyze table with location different than that of individual partitions") {
+ def queryTotalSize(tableName: String): BigInt =
+ spark.table(tableName).queryExecution.analyzed.stats(conf).sizeInBytes
+
+ val tableName = "analyzeTable_part"
+ withTable(tableName) {
+ withTempPath { path =>
+ sql(s"CREATE TABLE $tableName (key STRING, value STRING) PARTITIONED BY (ds STRING)")
+
+ val partitionDates = List("2010-01-01", "2010-01-02", "2010-01-03")
+ partitionDates.foreach { ds =>
+ sql(s"INSERT INTO TABLE $tableName PARTITION (ds='$ds') SELECT * FROM src")
+ }
+
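+      // Move the table location away from the partition directories; ANALYZE should still
+      // account for the data stored in the individual partitions.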
+ sql(s"ALTER TABLE $tableName SET LOCATION '$path'")
+
+ sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS noscan")
+
+ assert(queryTotalSize(tableName) === BigInt(17436))
+ }
+ }
+ }
+
+ test("SPARK-21079 - analyze partitioned table with only a subset of partitions visible") {
+ def queryTotalSize(tableName: String): BigInt =
+ spark.table(tableName).queryExecution.analyzed.stats(conf).sizeInBytes
+
+ val sourceTableName = "analyzeTable_part"
+ val tableName = "analyzeTable_part_vis"
+ withTable(sourceTableName, tableName) {
+ withTempPath { path =>
+ // Create a table with 3 partitions all located under a single top-level directory 'path'
+ sql(
+ s"""
+ |CREATE TABLE $sourceTableName (key STRING, value STRING)
+ |PARTITIONED BY (ds STRING)
+ |LOCATION '$path'
+ """.stripMargin)
+
+ val partitionDates = List("2010-01-01", "2010-01-02", "2010-01-03")
+ partitionDates.foreach { ds =>
+ sql(
+ s"""
+ |INSERT INTO TABLE $sourceTableName PARTITION (ds='$ds')
+ |SELECT * FROM src
+ """.stripMargin)
+ }
+
+ // Create another table referring to the same location
+ sql(
+ s"""
+ |CREATE TABLE $tableName (key STRING, value STRING)
+ |PARTITIONED BY (ds STRING)
+ |LOCATION '$path'
+ """.stripMargin)
+
+ // Register only one of the partitions found on disk
+ val ds = partitionDates.head
+ sql(s"ALTER TABLE $tableName ADD PARTITION (ds='$ds')").collect()
+
+ // Analyze original table - expect 3 partitions
+ sql(s"ANALYZE TABLE $sourceTableName COMPUTE STATISTICS noscan")
+ assert(queryTotalSize(sourceTableName) === BigInt(3 * 5812))
+
+ // Analyze partial-copy table - expect only 1 partition
+ sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS noscan")
+ assert(queryTotalSize(tableName) === BigInt(5812))
+ }
+ }
+ }
+
test("analyzing views is not supported") {
def assertAnalyzeUnsupported(analyzeCommand: String): Unit = {
val err = intercept[AnalysisException] {
@@ -145,23 +217,6 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
}
}
- private def checkTableStats(
- tableName: String,
- hasSizeInBytes: Boolean,
- expectedRowCounts: Option[Int]): Option[CatalogStatistics] = {
- val stats = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)).stats
-
- if (hasSizeInBytes || expectedRowCounts.nonEmpty) {
- assert(stats.isDefined)
- assert(stats.get.sizeInBytes > 0)
- assert(stats.get.rowCount === expectedRowCounts)
- } else {
- assert(stats.isEmpty)
- }
-
- stats
- }
-
test("test table-level statistics for hive tables created in HiveExternalCatalog") {
val textTable = "textTable"
withTable(textTable) {
@@ -442,7 +497,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
test("estimates the size of a test Hive serde tables") {
val df = sql("""SELECT * FROM src""")
val sizes = df.queryExecution.analyzed.collect {
- case relation: CatalogRelation => relation.stats(conf).sizeInBytes
+ case relation: HiveTableRelation => relation.stats(conf).sizeInBytes
}
assert(sizes.size === 1, s"Size wrong for:\n ${df.queryExecution}")
assert(sizes(0).equals(BigInt(5812)),
@@ -502,7 +557,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
() => (),
metastoreQuery,
metastoreAnswer,
- implicitly[ClassTag[CatalogRelation]]
+ implicitly[ClassTag[HiveTableRelation]]
)
}
@@ -516,7 +571,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
// Assert src has a size smaller than the threshold.
val sizes = df.queryExecution.analyzed.collect {
- case relation: CatalogRelation => relation.stats(conf).sizeInBytes
+ case relation: HiveTableRelation => relation.stats(conf).sizeInBytes
}
assert(sizes.size === 2 && sizes(1) <= spark.sessionState.conf.autoBroadcastJoinThreshold
&& sizes(0) <= spark.sessionState.conf.autoBroadcastJoinThreshold,
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala
index 031c1a5ec0ec3..9bc832a437c10 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala
@@ -65,6 +65,10 @@ class FiltersSuite extends SparkFunSuite with Logging {
(Literal("") === a("varchar", StringType)) :: Nil,
"")
+ filterTest("null-safe equals",
+ (Literal("test") <=> a("stringcol", StringType)) :: Nil,
+ "")
+
filterTest("SPARK-19912 String literals should be escaped for Hive metastore partition pruning",
(a("stringcol", StringType) === Literal("p1\" and q=\"q1")) ::
(Literal("p2\" and q=\"q2") === a("stringcol", StringType)) :: Nil,
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
index 7aff49c0fc3b1..7dd4fef193c3f 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.hive.client
-import java.io.{ByteArrayOutputStream, File, PrintStream}
+import java.io.{ByteArrayOutputStream, File, PrintStream, PrintWriter}
import java.net.URI
import org.apache.hadoop.conf.Configuration
@@ -164,6 +164,15 @@ class VersionsSuite extends SparkFunSuite with Logging {
client.createDatabase(tempDB, ignoreIfExists = true)
}
+ test(s"$version: createDatabase with null description") {
+ withTempDir { tmpDir =>
+ val dbWithNullDesc =
+ CatalogDatabase("dbWithNullDesc", description = null, tmpDir.toURI, Map())
+ client.createDatabase(dbWithNullDesc, ignoreIfExists = true)
+ assert(client.getDatabase("dbWithNullDesc").description == "")
+ }
+ }
+
test(s"$version: setCurrentDatabase") {
client.setCurrentDatabase("default")
}
@@ -697,6 +706,52 @@ class VersionsSuite extends SparkFunSuite with Logging {
assert(versionSpark.table("t1").collect() === Array(Row(2)))
}
}
+
+ test(s"$version: SPARK-17920: Insert into/overwrite avro table") {
+ withTempDir { dir =>
+ val destTableName = "tab1"
+ val avroSchema =
+ """{
+ |"type": "record",
+ | "name": "test_Record",
+ | "namespace": "ns.avro",
+ | "fields" : [
+ | {"name": "f1", "type": "string"},
+ | {"name": "f2", "type": ["null", "string"]}
+ | ]
+ |}
+ """.stripMargin
+
+ withTable(destTableName) {
+ val schemaFile = new File(dir, "avroSchema.avsc")
+ val writer = new PrintWriter(schemaFile)
+ writer.write(avroSchema)
+ writer.close()
+ val schemaPath = schemaFile.getCanonicalPath
+
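+        // Create a Hive Avro table whose schema is loaded from the avro.schema.url property.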
+ versionSpark.sql(
+ s"""CREATE TABLE $destTableName
+ |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.avro.AvroSerDe'
+ |STORED AS
+ | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat'
+ | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat'
+ |TBLPROPERTIES ('avro.schema.url' = '$schemaPath')
+ """.stripMargin
+ )
+ val insertStmt = s"INSERT OVERWRITE TABLE $destTableName SELECT 'ABC', 'DEF'"
+ if (version == "0.12") {
+ // Hive 0.12 throws AnalysisException
+ intercept[AnalysisException](versionSpark.sql(insertStmt))
+ } else {
+ val result = versionSpark.sql("SELECT 'ABC', 'DEF'").collect()
+ versionSpark.sql(insertStmt)
+ assert(versionSpark.table(destTableName).collect() === result)
+ versionSpark.sql(s"INSERT INTO TABLE $destTableName SELECT 'ABC', 'DEF'")
+ assert(versionSpark.table(destTableName).collect() === result ++ result)
+ }
+ }
+ }
+ }
// TODO: add more tests.
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 84f915977bd88..3a9e50c7685c0 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -1002,6 +1002,19 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
)
)
}
+
+ test("SPARK-24957: average with decimal followed by aggregation returning wrong result") {
+ val df = Seq(("a", BigDecimal("12.0")),
+ ("a", BigDecimal("12.0")),
+ ("a", BigDecimal("11.9999999988")),
+ ("a", BigDecimal("12.0")),
+ ("a", BigDecimal("12.0")),
+ ("a", BigDecimal("11.9999999988")),
+ ("a", BigDecimal("11.9999999988"))).toDF("text", "number")
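+    // Aggregating the decimal average a second time must not lose precision (SPARK-24957).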
+ val agg1 = df.groupBy($"text").agg(avg($"number").as("avg_res"))
+ val agg2 = agg1.groupBy($"text").agg(sum($"avg_res"))
+ checkAnswer(agg2, Row("a", BigDecimal("11.9999999994857142857143")))
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
index 3906968aaff10..f4c26256c0ee8 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
@@ -21,6 +21,8 @@ import java.io.File
import java.net.URI
import org.apache.hadoop.fs.Path
+import org.apache.parquet.format.converter.ParquetMetadataConverter.NO_FILTER
+import org.apache.parquet.hadoop.ParquetFileReader
import org.scalatest.BeforeAndAfterEach
import org.apache.spark.SparkException
@@ -30,12 +32,14 @@ import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.execution.command.{DDLSuite, DDLUtils}
import org.apache.spark.sql.hive.HiveExternalCatalog
+import org.apache.spark.sql.hive.HiveUtils.{CONVERT_METASTORE_ORC, CONVERT_METASTORE_PARQUET}
import org.apache.spark.sql.hive.orc.OrcFileOperator
import org.apache.spark.sql.hive.test.TestHiveSingleton
-import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
// TODO(gatorsmile): combine HiveCatalogedDDLSuite and HiveDDLSuite
class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeAndAfterEach {
@@ -50,15 +54,28 @@ class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeA
protected override def generateTable(
catalog: SessionCatalog,
- name: TableIdentifier): CatalogTable = {
+ name: TableIdentifier,
+ isDataSource: Boolean): CatalogTable = {
val storage =
- CatalogStorageFormat(
- locationUri = Some(catalog.defaultTablePath(name)),
- inputFormat = Some("org.apache.hadoop.mapred.SequenceFileInputFormat"),
- outputFormat = Some("org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat"),
- serde = Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"),
- compressed = false,
- properties = Map("serialization.format" -> "1"))
+ if (isDataSource) {
+ val serde = HiveSerDe.sourceToSerDe("parquet")
+ assert(serde.isDefined, "The default format is not Hive compatible")
+ CatalogStorageFormat(
+ locationUri = Some(catalog.defaultTablePath(name)),
+ inputFormat = serde.get.inputFormat,
+ outputFormat = serde.get.outputFormat,
+ serde = serde.get.serde,
+ compressed = false,
+ properties = Map("serialization.format" -> "1"))
+ } else {
+ CatalogStorageFormat(
+ locationUri = Some(catalog.defaultTablePath(name)),
+ inputFormat = Some("org.apache.hadoop.mapred.SequenceFileInputFormat"),
+ outputFormat = Some("org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat"),
+ serde = Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"),
+ compressed = false,
+ properties = Map("serialization.format" -> "1"))
+ }
val metadata = new MetadataBuilder()
.putString("key", "value")
.build()
@@ -71,7 +88,7 @@ class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeA
.add("col2", "string")
.add("a", "int")
.add("b", "int"),
- provider = Some("hive"),
+ provider = if (isDataSource) Some("parquet") else Some("hive"),
partitionColumnNames = Seq("a", "b"),
createTime = 0L,
tracksPartitionsInCatalog = true)
@@ -107,6 +124,46 @@ class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeA
)
}
+ test("alter table: set location") {
+ testSetLocation(isDatasourceTable = false)
+ }
+
+ test("alter table: set properties") {
+ testSetProperties(isDatasourceTable = false)
+ }
+
+ test("alter table: unset properties") {
+ testUnsetProperties(isDatasourceTable = false)
+ }
+
+ test("alter table: set serde") {
+ testSetSerde(isDatasourceTable = false)
+ }
+
+ test("alter table: set serde partition") {
+ testSetSerdePartition(isDatasourceTable = false)
+ }
+
+ test("alter table: change column") {
+ testChangeColumn(isDatasourceTable = false)
+ }
+
+ test("alter table: rename partition") {
+ testRenamePartitions(isDatasourceTable = false)
+ }
+
+ test("alter table: drop partition") {
+ testDropPartitions(isDatasourceTable = false)
+ }
+
+ test("alter table: add partition") {
+ testAddPartitions(isDatasourceTable = false)
+ }
+
+ test("drop table") {
+ testDropTable(isDatasourceTable = false)
+ }
+
}
class HiveDDLSuite
@@ -130,7 +187,7 @@ class HiveDDLSuite
if (dbPath.isEmpty) {
hiveContext.sessionState.catalog.defaultTablePath(tableIdentifier)
} else {
- new Path(new Path(dbPath.get), tableIdentifier.table)
+ new Path(new Path(dbPath.get), tableIdentifier.table).toUri
}
val filesystemPath = new Path(expectedTablePath.toString)
val fs = filesystemPath.getFileSystem(spark.sessionState.newHadoopConf())
@@ -732,7 +789,7 @@ class HiveDDLSuite
checkAnswer(
sql(s"DESC $tabName").select("col_name", "data_type", "comment"),
- Row("# col_name", "data_type", "comment") :: Row("a", "int", "test") :: Nil
+ Row("a", "int", "test") :: Nil
)
}
}
@@ -1197,6 +1254,14 @@ class HiveDDLSuite
s"CREATE INDEX $indexName ON TABLE $tabName (a) AS 'COMPACT' WITH DEFERRED REBUILD")
val indexTabName =
spark.sessionState.catalog.listTables("default", s"*$indexName*").head.table
+
+ // Even if index tables exist, listTables and getTable APIs should still work
+ checkAnswer(
+ spark.catalog.listTables().toDF(),
+ Row(indexTabName, "default", null, null, false) ::
+ Row(tabName, "default", null, "MANAGED", false) :: Nil)
+ assert(spark.catalog.getTable("default", indexTabName).name === indexTabName)
+
intercept[TableAlreadyExistsException] {
sql(s"CREATE TABLE $indexTabName(b int)")
}
@@ -1376,12 +1441,8 @@ class HiveDDLSuite
sql("INSERT INTO t SELECT 1")
checkAnswer(spark.table("t"), Row(1))
// Check if this is compressed as ZLIB.
- val maybeOrcFile = path.listFiles().find(!_.getName.endsWith(".crc"))
- assert(maybeOrcFile.isDefined)
- val orcFilePath = maybeOrcFile.get.toPath.toString
- val expectedCompressionKind =
- OrcFileOperator.getFileReader(orcFilePath).get.getCompression
- assert("ZLIB" === expectedCompressionKind.name())
+ val maybeOrcFile = path.listFiles().find(_.getName.startsWith("part"))
+ assertCompression(maybeOrcFile, "orc", "ZLIB")
sql("CREATE TABLE t2 USING HIVE AS SELECT 1 AS c1, 'a' AS c2")
val table2 = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t2"))
@@ -1575,7 +1636,7 @@ class HiveDDLSuite
test("create hive table with a non-existing location") {
withTable("t", "t1") {
withTempPath { dir =>
- spark.sql(s"CREATE TABLE t(a int, b int) USING hive LOCATION '$dir'")
+ spark.sql(s"CREATE TABLE t(a int, b int) USING hive LOCATION '${dir.toURI}'")
val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t"))
assert(table.location == makeQualifiedPath(dir.getAbsolutePath))
@@ -1592,7 +1653,7 @@ class HiveDDLSuite
|CREATE TABLE t1(a int, b int)
|USING hive
|PARTITIONED BY(a)
- |LOCATION '$dir'
+ |LOCATION '${dir.toURI}'
""".stripMargin)
val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1"))
@@ -1620,7 +1681,7 @@ class HiveDDLSuite
s"""
|CREATE TABLE t
|USING hive
- |LOCATION '$dir'
+ |LOCATION '${dir.toURI}'
|AS SELECT 3 as a, 4 as b, 1 as c, 2 as d
""".stripMargin)
val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t"))
@@ -1636,7 +1697,7 @@ class HiveDDLSuite
|CREATE TABLE t1
|USING hive
|PARTITIONED BY(a, b)
- |LOCATION '$dir'
+ |LOCATION '${dir.toURI}'
|AS SELECT 3 as a, 4 as b, 1 as c, 2 as d
""".stripMargin)
val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1"))
@@ -1662,21 +1723,21 @@ class HiveDDLSuite
|CREATE TABLE t(a string, `$specialChars` string)
|USING $datasource
|PARTITIONED BY(`$specialChars`)
- |LOCATION '$dir'
+ |LOCATION '${dir.toURI}'
""".stripMargin)
assert(dir.listFiles().isEmpty)
spark.sql(s"INSERT INTO TABLE t PARTITION(`$specialChars`=2) SELECT 1")
val partEscaped = s"${ExternalCatalogUtils.escapePathName(specialChars)}=2"
val partFile = new File(dir, partEscaped)
- assert(partFile.listFiles().length >= 1)
+ assert(partFile.listFiles().nonEmpty)
checkAnswer(spark.table("t"), Row("1", "2") :: Nil)
withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") {
spark.sql(s"INSERT INTO TABLE t PARTITION(`$specialChars`) SELECT 3, 4")
val partEscaped1 = s"${ExternalCatalogUtils.escapePathName(specialChars)}=4"
val partFile1 = new File(dir, partEscaped1)
- assert(partFile1.listFiles().length >= 1)
+ assert(partFile1.listFiles().nonEmpty)
checkAnswer(spark.table("t"), Row("1", "2") :: Row("3", "4") :: Nil)
}
}
@@ -1687,15 +1748,22 @@ class HiveDDLSuite
Seq("a b", "a:b", "a%b").foreach { specialChars =>
test(s"hive table: location uri contains $specialChars") {
+      // On Windows, it looks like a colon in the file name is illegal by default. See
+ // https://support.microsoft.com/en-us/help/289627
+ assume(!Utils.isWindows || specialChars != "a:b")
+
withTable("t") {
withTempDir { dir =>
val loc = new File(dir, specialChars)
loc.mkdir()
+          // The parser does not recognize backslashes on Windows as-is,
+          // so they currently need to be escaped.
+ val escapedLoc = loc.getAbsolutePath.replace("\\", "\\\\")
spark.sql(
s"""
|CREATE TABLE t(a string)
|USING hive
- |LOCATION '$loc'
+ |LOCATION '$escapedLoc'
""".stripMargin)
val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t"))
@@ -1718,12 +1786,13 @@ class HiveDDLSuite
withTempDir { dir =>
val loc = new File(dir, specialChars)
loc.mkdir()
+ val escapedLoc = loc.getAbsolutePath.replace("\\", "\\\\")
spark.sql(
s"""
|CREATE TABLE t1(a string, b string)
|USING hive
|PARTITIONED BY(b)
- |LOCATION '$loc'
+ |LOCATION '$escapedLoc'
""".stripMargin)
val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1"))
@@ -1734,16 +1803,20 @@ class HiveDDLSuite
if (specialChars != "a:b") {
spark.sql("INSERT INTO TABLE t1 PARTITION(b=2) SELECT 1")
val partFile = new File(loc, "b=2")
- assert(partFile.listFiles().length >= 1)
+ assert(partFile.listFiles().nonEmpty)
checkAnswer(spark.table("t1"), Row("1", "2") :: Nil)
spark.sql("INSERT INTO TABLE t1 PARTITION(b='2017-03-03 12:13%3A14') SELECT 1")
val partFile1 = new File(loc, "b=2017-03-03 12:13%3A14")
assert(!partFile1.exists())
- val partFile2 = new File(loc, "b=2017-03-03 12%3A13%253A14")
- assert(partFile2.listFiles().length >= 1)
- checkAnswer(spark.table("t1"),
- Row("1", "2") :: Row("1", "2017-03-03 12:13%3A14") :: Nil)
+
+ if (!Utils.isWindows) {
+ // Actual path becomes "b=2017-03-03%2012%3A13%253A14" on Windows.
+ val partFile2 = new File(loc, "b=2017-03-03 12%3A13%253A14")
+ assert(partFile2.listFiles().nonEmpty)
+ checkAnswer(spark.table("t1"),
+ Row("1", "2") :: Row("1", "2017-03-03 12:13%3A14") :: Nil)
+ }
} else {
val e = intercept[AnalysisException] {
spark.sql("INSERT INTO TABLE t1 PARTITION(b=2) SELECT 1")
@@ -1867,4 +1940,56 @@ class HiveDDLSuite
}
}
}
+
+ private def assertCompression(maybeFile: Option[File], format: String, compression: String) = {
+ assert(maybeFile.isDefined)
+
+ val actualCompression = format match {
+ case "orc" =>
+ OrcFileOperator.getFileReader(maybeFile.get.toPath.toString).get.getCompression.name
+
+ case "parquet" =>
+ val footer = ParquetFileReader.readFooter(
+ sparkContext.hadoopConfiguration, new Path(maybeFile.get.getPath), NO_FILTER)
+ footer.getBlocks.get(0).getColumns.get(0).getCodec.toString
+ }
+
+ assert(compression === actualCompression)
+ }
+
+ Seq(("orc", "ZLIB"), ("parquet", "GZIP")).foreach { case (fileFormat, compression) =>
+ test(s"SPARK-22158 convertMetastore should not ignore table property - $fileFormat") {
+ withSQLConf(CONVERT_METASTORE_ORC.key -> "true", CONVERT_METASTORE_PARQUET.key -> "true") {
+ withTable("t") {
+ withTempPath { path =>
+ sql(
+ s"""
+ |CREATE TABLE t(id int) USING hive
+ |OPTIONS(fileFormat '$fileFormat', compression '$compression')
+ |LOCATION '${path.toURI}'
+ """.stripMargin)
+ val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t"))
+ assert(DDLUtils.isHiveTable(table))
+ assert(table.storage.serde.get.contains(fileFormat))
+ assert(table.storage.properties.get("compression") == Some(compression))
+ assert(spark.table("t").collect().isEmpty)
+
+ sql("INSERT INTO t SELECT 1")
+ checkAnswer(spark.table("t"), Row(1))
+ val maybeFile = path.listFiles().find(_.getName.startsWith("part"))
+ assertCompression(maybeFile, fileFormat, compression)
+ }
+ }
+ }
+ }
+ }
+
+ test("load command for non local invalid path validation") {
+ withTable("tbl") {
+ sql("CREATE TABLE tbl(i INT, j STRING)")
+ val e = intercept[AnalysisException](
+ sql("load data inpath '/doesnotexist.csv' into table tbl"))
+ assert(e.message.contains("LOAD DATA input path does not exist"))
+ }
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala
index 8a37bc3665d32..aa1ca2909074f 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala
@@ -43,11 +43,29 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto
test("explain extended command") {
checkKeywordsExist(sql(" explain select * from src where key=123 "),
- "== Physical Plan ==")
+ "== Physical Plan ==",
+ "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")
+
checkKeywordsNotExist(sql(" explain select * from src where key=123 "),
"== Parsed Logical Plan ==",
"== Analyzed Logical Plan ==",
- "== Optimized Logical Plan ==")
+ "== Optimized Logical Plan ==",
+ "Owner",
+ "Database",
+ "Created",
+ "Last Access",
+ "Type",
+ "Provider",
+ "Properties",
+ "Statistics",
+ "Location",
+ "Serde Library",
+ "InputFormat",
+ "OutputFormat",
+ "Partition Provider",
+ "Schema"
+ )
+
checkKeywordsExist(sql(" explain extended select * from src where key=123 "),
"== Parsed Logical Plan ==",
"== Analyzed Logical Plan ==",
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
index 90e037e292790..ae64cb3210b53 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
@@ -164,16 +164,30 @@ class HiveTableScanSuite extends HiveComparisonTest with SQLTestUtils with TestH
|PARTITION (p1='a',p2='c',p3='c',p4='d',p5='e')
|SELECT v.id
""".stripMargin)
- val plan = sql(
- s"""
- |SELECT * FROM $table
- """.stripMargin).queryExecution.sparkPlan
- val scan = plan.collectFirst {
- case p: HiveTableScanExec => p
- }.get
+ val scan = getHiveTableScanExec(s"SELECT * FROM $table")
val numDataCols = scan.relation.dataCols.length
scan.rawPartitions.foreach(p => assert(p.getCols.size == numDataCols))
}
}
}
+
+ test("HiveTableScanExec canonicalization for different orders of partition filters") {
+ val table = "hive_tbl_part"
+ withTable(table) {
+ sql(
+ s"""
+ |CREATE TABLE $table (id int)
+ |PARTITIONED BY (a int, b int)
+ """.stripMargin)
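+    // The same partition predicates in a different order should canonicalize to equal scans.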
+ val scan1 = getHiveTableScanExec(s"SELECT * FROM $table WHERE a = 1 AND b = 2")
+ val scan2 = getHiveTableScanExec(s"SELECT * FROM $table WHERE b = 2 AND a = 1")
+ assert(scan1.sameResult(scan2))
+ }
+ }
+
+ private def getHiveTableScanExec(query: String): HiveTableScanExec = {
+ sql(query).queryExecution.sparkPlan.collectFirst {
+ case p: HiveTableScanExec => p
+ }.get
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
index 4446af2e75e00..8fcbad58350f4 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -34,6 +34,7 @@ import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.catalyst.plans.logical.Project
import org.apache.spark.sql.functions.max
import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.util.Utils
@@ -590,6 +591,25 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
}
}
}
+
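+  // A permanent function in another database should be callable via its qualified name,
+  // regardless of the case-sensitivity setting.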
+ test("Call the function registered in the not-current database") {
+ Seq("true", "false").foreach { caseSensitive =>
+ withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive) {
+ withDatabase("dAtABaSe1") {
+ sql("CREATE DATABASE dAtABaSe1")
+ withUserDefinedFunction("dAtABaSe1.test_avg" -> false) {
+ sql(s"CREATE FUNCTION dAtABaSe1.test_avg AS '${classOf[GenericUDAFAverage].getName}'")
+ checkAnswer(sql("SELECT dAtABaSe1.test_avg(1)"), Row(1.0))
+ }
+ val message = intercept[AnalysisException] {
+ sql("SELECT dAtABaSe1.unknownFunc(1)")
+ }.getMessage
+ assert(message.contains("Undefined function: 'unknownFunc'") &&
+ message.contains("nor a permanent function registered in the database 'dAtABaSe1'"))
+ }
+ }
+ }
+ }
}
class TestPair(x: Int, y: Int) extends Writable with Serializable {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala
new file mode 100644
index 0000000000000..bc828877e35ec
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import scala.language.existentials
+
+import org.apache.hadoop.conf.Configuration
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.launcher.SparkLauncher
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.catalog._
+import org.apache.spark.sql.hive.{HiveExternalCatalog, HiveUtils}
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.internal.StaticSQLConf._
+import org.apache.spark.sql.types._
+import org.apache.spark.tags.ExtendedHiveTest
+import org.apache.spark.util.Utils
+
+/**
+ * A separate set of DDL tests that uses Hive 2.1 libraries, which behave a little differently
+ * from the built-in ones.
+ */
+@ExtendedHiveTest
+class Hive_2_1_DDLSuite extends SparkFunSuite with TestHiveSingleton with BeforeAndAfterEach
+ with BeforeAndAfterAll {
+
+ // Create a custom HiveExternalCatalog instance with the desired configuration. We cannot
+  // use SparkSession here since there's already an active one managed by the TestHive object.
+ private var catalog = {
+ val warehouse = Utils.createTempDir()
+ val metastore = Utils.createTempDir()
+ metastore.delete()
+ val sparkConf = new SparkConf()
+ .set(SparkLauncher.SPARK_MASTER, "local")
+ .set(WAREHOUSE_PATH.key, warehouse.toURI().toString())
+ .set(CATALOG_IMPLEMENTATION.key, "hive")
+ .set(HiveUtils.HIVE_METASTORE_VERSION.key, "2.1")
+ .set(HiveUtils.HIVE_METASTORE_JARS.key, "maven")
+
+ val hadoopConf = new Configuration()
+ hadoopConf.set("hive.metastore.warehouse.dir", warehouse.toURI().toString())
+ hadoopConf.set("javax.jdo.option.ConnectionURL",
+ s"jdbc:derby:;databaseName=${metastore.getAbsolutePath()};create=true")
+ // These options are needed since the defaults in Hive 2.1 cause exceptions with an
+ // empty metastore db.
+ hadoopConf.set("datanucleus.schema.autoCreateAll", "true")
+ hadoopConf.set("hive.metastore.schema.verification", "false")
+
+ new HiveExternalCatalog(sparkConf, hadoopConf)
+ }
+
+ override def afterEach: Unit = {
+ catalog.listTables("default").foreach { t =>
+ catalog.dropTable("default", t, true, false)
+ }
+ spark.sessionState.catalog.reset()
+ }
+
+ override def afterAll(): Unit = {
+ catalog = null
+ }
+
+ test("SPARK-21617: ALTER TABLE for non-compatible DataSource tables") {
+ testAlterTable(
+ "t1",
+ "CREATE TABLE t1 (c1 int) USING json",
+ StructType(Array(StructField("c1", IntegerType), StructField("c2", IntegerType))),
+ hiveCompatible = false)
+ }
+
+ test("SPARK-21617: ALTER TABLE for Hive-compatible DataSource tables") {
+ testAlterTable(
+ "t1",
+ "CREATE TABLE t1 (c1 int) USING parquet",
+ StructType(Array(StructField("c1", IntegerType), StructField("c2", IntegerType))))
+ }
+
+ test("SPARK-21617: ALTER TABLE for Hive tables") {
+ testAlterTable(
+ "t1",
+ "CREATE TABLE t1 (c1 int) STORED AS parquet",
+ StructType(Array(StructField("c1", IntegerType), StructField("c2", IntegerType))))
+ }
+
+ test("SPARK-21617: ALTER TABLE with incompatible schema on Hive-compatible table") {
+ val exception = intercept[AnalysisException] {
+ testAlterTable(
+ "t1",
+ "CREATE TABLE t1 (c1 string) USING parquet",
+ StructType(Array(StructField("c2", IntegerType))))
+ }
+ assert(exception.getMessage().contains("types incompatible with the existing columns"))
+ }
+
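+  // Creates the table through the regular session catalog, registers it in the Hive 2.1
+  // backed catalog, applies the new data schema, and checks the resulting field names.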
+ private def testAlterTable(
+ tableName: String,
+ createTableStmt: String,
+ updatedSchema: StructType,
+ hiveCompatible: Boolean = true): Unit = {
+ spark.sql(createTableStmt)
+ val oldTable = spark.sessionState.catalog.externalCatalog.getTable("default", tableName)
+ catalog.createTable(oldTable, true)
+ catalog.alterTableDataSchema("default", tableName, updatedSchema)
+
+ val updatedTable = catalog.getTable("default", tableName)
+ assert(updatedTable.schema.fieldNames === updatedSchema.fieldNames)
+ }
+
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala
index f818e29555468..d91f25a4da013 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.hive.execution
import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
@@ -66,4 +67,28 @@ class PruneFileSourcePartitionsSuite extends QueryTest with SQLTestUtils with Te
}
}
}
+
+ test("SPARK-20986 Reset table's statistics after PruneFileSourcePartitions rule") {
+ withTable("tbl") {
+ spark.range(10).selectExpr("id", "id % 3 as p").write.partitionBy("p").saveAsTable("tbl")
+ sql(s"ANALYZE TABLE tbl COMPUTE STATISTICS")
+ val tableStats = spark.sessionState.catalog.getTableMetadata(TableIdentifier("tbl")).stats
+ assert(tableStats.isDefined && tableStats.get.sizeInBytes > 0, "tableStats is lost")
+
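+      // Before optimization the relation reports the full table size; after partition
+      // pruning its statistics should be recomputed and become smaller.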
+ val df = sql("SELECT * FROM tbl WHERE p = 1")
+ val sizes1 = df.queryExecution.analyzed.collect {
+ case relation: LogicalRelation => relation.catalogTable.get.stats.get.sizeInBytes
+ }
+ assert(sizes1.size === 1, s"Size wrong for:\n ${df.queryExecution}")
+ assert(sizes1(0) == tableStats.get.sizeInBytes)
+
+ val relations = df.queryExecution.optimizedPlan.collect {
+ case relation: LogicalRelation => relation
+ }
+ assert(relations.size === 1, s"Size wrong for:\n ${df.queryExecution}")
+ val size2 = relations(0).computeStats(conf).sizeInBytes
+ assert(size2 == relations(0).catalogTable.get.stats.get.sizeInBytes)
+ assert(size2 < tableStats.get.sizeInBytes)
+ }
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 75f3744ff35be..3d1027aecad7d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -20,21 +20,21 @@ package org.apache.spark.sql.hive.execution
import java.io.File
import java.nio.charset.StandardCharsets
import java.sql.{Date, Timestamp}
-import java.util.Locale
+import java.util.{Locale, Set}
import com.google.common.io.Files
-import org.apache.hadoop.fs.Path
+import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.TestUtils
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, FunctionRegistry, NoSuchPartitionException}
-import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTableType, CatalogUtils}
+import org.apache.spark.sql.catalyst.catalog.{CatalogTableType, CatalogUtils, HiveTableRelation}
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.hive.HiveUtils
+import org.apache.spark.sql.hive.{HiveExternalCatalog, HiveUtils}
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
@@ -454,7 +454,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
case LogicalRelation(r: HadoopFsRelation, _, _) =>
if (!isDataSourceTable) {
fail(
- s"${classOf[CatalogRelation].getCanonicalName} is expected, but found " +
+ s"${classOf[HiveTableRelation].getCanonicalName} is expected, but found " +
s"${HadoopFsRelation.getClass.getCanonicalName}.")
}
userSpecifiedLocation match {
@@ -464,11 +464,11 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
}
assert(catalogTable.provider.get === format)
- case r: CatalogRelation =>
+ case r: HiveTableRelation =>
if (isDataSourceTable) {
fail(
s"${HadoopFsRelation.getClass.getCanonicalName} is expected, but found " +
- s"${classOf[CatalogRelation].getCanonicalName}.")
+ s"${classOf[HiveTableRelation].getCanonicalName}.")
}
userSpecifiedLocation match {
case Some(location) =>
@@ -948,7 +948,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
withSQLConf(SQLConf.CONVERT_CTAS.key -> "false") {
sql("CREATE TABLE explodeTest (key bigInt)")
table("explodeTest").queryExecution.analyzed match {
- case SubqueryAlias(_, r: CatalogRelation) => // OK
+ case SubqueryAlias(_, r: HiveTableRelation) => // OK
case _ =>
fail("To correctly test the fix of SPARK-5875, explodeTest should be a MetastoreRelation")
}
@@ -1976,6 +1976,30 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
}
}
+ test("Auto alias construction of get_json_object") {
+ val df = Seq(("1", """{"f1": "value1", "f5": 5.23}""")).toDF("key", "jstring")
+ val expectedMsg = "Cannot create a table having a column whose name contains commas " +
+ "in Hive metastore. Table: `default`.`t`; Column: get_json_object(jstring, $.f1)"
+
+ withTable("t") {
+ val e = intercept[AnalysisException] {
+ df.select($"key", functions.get_json_object($"jstring", "$.f1"))
+ .write.format("hive").saveAsTable("t")
+ }.getMessage
+ assert(e.contains(expectedMsg))
+ }
+
+ withTempView("tempView") {
+ withTable("t") {
+ df.createTempView("tempView")
+ val e = intercept[AnalysisException] {
+ sql("CREATE TABLE t AS SELECT key, get_json_object(jstring, '$.f1') FROM tempView")
+ }.getMessage
+ assert(e.contains(expectedMsg))
+ }
+ }
+ }
+
test("SPARK-19912 String literals should be escaped for Hive metastore partition pruning") {
withTable("spark_19912") {
Seq(
@@ -1991,4 +2015,83 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
checkAnswer(table.filter($"p" === "p1\" and q=\"q1").select($"a"), Row(4))
}
}
+
+  test("SPARK-21721: Clear FileSystem deleteOnExit cache if path is successfully removed") {
+ val table = "test21721"
+ withTable(table) {
+ val deleteOnExitField = classOf[FileSystem].getDeclaredField("deleteOnExit")
+ deleteOnExitField.setAccessible(true)
+
+ val fs = FileSystem.get(spark.sparkContext.hadoopConfiguration)
+ val setOfPath = deleteOnExitField.get(fs).asInstanceOf[Set[Path]]
+
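+      // Repeated inserts should not leave stale entries in the FileSystem deleteOnExit set.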
+ val testData = sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)).toDF()
+ sql(s"CREATE TABLE $table (key INT, value STRING)")
+ val pathSizeToDeleteOnExit = setOfPath.size()
+
+ (0 to 10).foreach(_ => testData.write.mode(SaveMode.Append).insertInto(table))
+
+ assert(setOfPath.size() == pathSizeToDeleteOnExit)
+ }
+ }
+
+ Seq("orc", "parquet").foreach { format =>
+ test(s"SPARK-18355 Read data from a hive table with a new column - $format") {
+ val client = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client
+
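+      // Add a column through Hive directly, then read the table with and without the
+      // metastore ORC/Parquet conversion enabled.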
+ Seq("true", "false").foreach { value =>
+ withSQLConf(
+ HiveUtils.CONVERT_METASTORE_ORC.key -> value,
+ HiveUtils.CONVERT_METASTORE_PARQUET.key -> value) {
+ withTempDatabase { db =>
+ client.runSqlHive(
+ s"""
+ |CREATE TABLE $db.t(
+ | click_id string,
+ | search_id string,
+ | uid bigint)
+ |PARTITIONED BY (
+ | ts string,
+ | hour string)
+ |STORED AS $format
+ """.stripMargin)
+
+ client.runSqlHive(
+ s"""
+ |INSERT INTO TABLE $db.t
+ |PARTITION (ts = '98765', hour = '01')
+ |VALUES (12, 2, 12345)
+ """.stripMargin
+ )
+
+ checkAnswer(
+ sql(s"SELECT click_id, search_id, uid, ts, hour FROM $db.t"),
+ Row("12", "2", 12345, "98765", "01"))
+
+ client.runSqlHive(s"ALTER TABLE $db.t ADD COLUMNS (dummy string)")
+
+ checkAnswer(
+ sql(s"SELECT click_id, search_id FROM $db.t"),
+ Row("12", "2"))
+
+ checkAnswer(
+ sql(s"SELECT search_id, click_id FROM $db.t"),
+ Row("2", "12"))
+
+ checkAnswer(
+ sql(s"SELECT search_id FROM $db.t"),
+ Row("2"))
+
+ checkAnswer(
+ sql(s"SELECT dummy, click_id FROM $db.t"),
+ Row(null, "12"))
+
+ checkAnswer(
+ sql(s"SELECT click_id, search_id, uid, dummy, ts, hour FROM $db.t"),
+ Row("12", "2", 12345, null, "98765", "01"))
+ }
+ }
+ }
+ }
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala
index a20c758a83e71..3f9485dd018b1 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala
@@ -232,31 +232,4 @@ class WindowQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleto
Row("Manufacturer#5", "almond azure blanched chiffon midnight", 23, 315.9225931564038, 315.9225931564038, 46, 99807.08486666666, -0.9978877469246935, -5664.856666666666)))
// scalastyle:on
}
-
- test("null arguments") {
- checkAnswer(sql("""
- |select p_mfgr, p_name, p_size,
- |sum(null) over(distribute by p_mfgr sort by p_name) as sum,
- |avg(null) over(distribute by p_mfgr sort by p_name) as avg
- |from part
- """.stripMargin),
- sql("""
- |select p_mfgr, p_name, p_size,
- |null as sum,
- |null as avg
- |from part
- """.stripMargin))
- }
-
- test("SPARK-16646: LAST_VALUE(FALSE) OVER ()") {
- checkAnswer(sql("SELECT LAST_VALUE(FALSE) OVER ()"), Row(false))
- checkAnswer(sql("SELECT LAST_VALUE(FALSE, FALSE) OVER ()"), Row(false))
- checkAnswer(sql("SELECT LAST_VALUE(TRUE, TRUE) OVER ()"), Row(true))
- }
-
- test("SPARK-16646: FIRST_VALUE(FALSE) OVER ()") {
- checkAnswer(sql("SELECT FIRST_VALUE(FALSE) OVER ()"), Row(false))
- checkAnswer(sql("SELECT FIRST_VALUE(FALSE, FALSE) OVER ()"), Row(false))
- checkAnswer(sql("SELECT FIRST_VALUE(TRUE, TRUE) OVER ()"), Row(true))
- }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
index 8c855730c31f2..60ccd996d6d58 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
@@ -26,7 +26,7 @@ import org.scalatest.BeforeAndAfterAll
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.catalog.CatalogRelation
+import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
import org.apache.spark.sql.execution.datasources.{LogicalRelation, RecordReaderIterator}
import org.apache.spark.sql.hive.HiveUtils
import org.apache.spark.sql.hive.test.TestHive._
@@ -475,7 +475,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
}
} else {
queryExecution.analyzed.collectFirst {
- case _: CatalogRelation => ()
+ case _: HiveTableRelation => ()
}.getOrElse {
fail(s"Expecting no conversion from orc to data sources, " +
s"but got:\n$queryExecution")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala
index 6bfb88c0c1af5..a562de47b9109 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala
@@ -22,9 +22,12 @@ import java.io.File
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.hive.HiveExternalCatalog
import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.internal.HiveSerDe
import org.apache.spark.sql.sources._
+import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -197,7 +200,7 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA
}
}
-class OrcSourceSuite extends OrcSuite {
+class OrcSourceSuite extends OrcSuite with SQLTestUtils {
override def beforeAll(): Unit = {
super.beforeAll()
@@ -250,4 +253,31 @@ class OrcSourceSuite extends OrcSuite {
)).get.toString
}
}
+
+ test("SPARK-22972: hive orc source") {
+ val tableName = "normal_orc_as_source_hive"
+ withTable(tableName) {
+ spark.sql(
+ s"""
+ |CREATE TABLE $tableName
+ |USING org.apache.spark.sql.hive.orc
+ |OPTIONS (
+ | PATH '${new File(orcTableAsDir.getAbsolutePath).toURI}'
+ |)
+ """.stripMargin)
+
+ val tableMetadata = spark.sessionState.catalog.getTableMetadata(
+ TableIdentifier(tableName))
+ assert(tableMetadata.storage.inputFormat ==
+ Option("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat"))
+ assert(tableMetadata.storage.outputFormat ==
+ Option("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat"))
+ assert(tableMetadata.storage.serde ==
+ Option("org.apache.hadoop.hive.ql.io.orc.OrcSerde"))
+ assert(HiveSerDe.sourceToSerDe("org.apache.spark.sql.hive.orc")
+ .equals(HiveSerDe.sourceToSerDe("orc")))
+ assert(HiveSerDe.sourceToSerDe("org.apache.spark.sql.orc")
+ .equals(HiveSerDe.sourceToSerDe("orc")))
+ }
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
index 23f21e6b9931e..303884da19f09 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
@@ -21,7 +21,7 @@ import java.io.File
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.catalog.CatalogRelation
+import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.DataSourceScanExec
import org.apache.spark.sql.execution.datasources._
@@ -812,7 +812,7 @@ class ParquetSourceSuite extends ParquetPartitioningTest {
}
} else {
queryExecution.analyzed.collectFirst {
- case _: CatalogRelation =>
+ case _: HiveTableRelation =>
}.getOrElse {
fail(s"Expecting no conversion from parquet to data sources, " +
s"but got:\n$queryExecution")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
index 9f4009bfe402a..60a4638f610b3 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
@@ -103,7 +103,7 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister {
// `Cast`ed values are always of internal types (e.g. UTF8String instead of String)
Cast(Literal(value), dataType).eval()
})
- }.filter(predicate).map(projection)
+ }.filter(predicate.eval).map(projection)
// Appends partition values
val fullOutput = requiredSchema.toAttributes ++ partitionSchema.toAttributes
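Switching from .filter(predicate) to .filter(predicate.eval) suggests the predicate here is an object exposing an eval method rather than a plain function. A minimal sketch of the difference, using a hypothetical RowPredicate trait instead of Catalyst's generated predicate classes:

    // Hypothetical types, not Spark's catalyst API: a plain function versus an
    // object whose test lives behind an eval method.
    object PredicateEvalSketch {
      type Row = Seq[Any]

      // Stand-in for a compiled/generated predicate that is not itself Row => Boolean.
      trait RowPredicate {
        def eval(row: Row): Boolean
      }

      def main(args: Array[String]): Unit = {
        val rows: Seq[Row] = Seq(Seq(1, "a"), Seq(2, "b"), Seq(3, "c"))

        val fn: Row => Boolean = row => row.head.asInstanceOf[Int] > 1
        val compiled: RowPredicate = new RowPredicate {
          def eval(row: Row): Boolean = row.head.asInstanceOf[Int] > 1
        }

        // A function can be passed to filter directly; an eval-style predicate is
        // adapted with a method reference, as the one-line change above does.
        val viaFunction = rows.filter(fn)
        val viaEval = rows.filter(compiled.eval)
        assert(viaFunction == viaEval)
        println(viaEval)
      }
    }

If predicate is such an eval-style object, as the change suggests, the method reference eta-expands to the function type that filter expects, so the behavior of the filter itself is unchanged.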
diff --git a/streaming/pom.xml b/streaming/pom.xml
index de1be9c13e05f..0be5e4fc47193 100644
--- a/streaming/pom.xml
+++ b/streaming/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent_2.11
- 2.2.0-SNAPSHOT
+ 2.2.3-SNAPSHOT
../pom.xml
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
index 5cbad8bf3ce6e..0f59878cc4019 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
@@ -51,14 +51,20 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time)
"spark.yarn.app.id",
"spark.yarn.app.attemptId",
"spark.driver.host",
+ "spark.driver.bindAddress",
"spark.driver.port",
"spark.master",
+ "spark.yarn.jars",
"spark.yarn.keytab",
"spark.yarn.principal",
+ "spark.yarn.credentials.file",
+ "spark.yarn.credentials.renewalTime",
+ "spark.yarn.credentials.updateTime",
"spark.ui.filters")
val newSparkConf = new SparkConf(loadDefaults = false).setAll(sparkConfPairs)
.remove("spark.driver.host")
+ .remove("spark.driver.bindAddress")
.remove("spark.driver.port")
val newReloadConf = new SparkConf(loadDefaults = true)
propertiesToReload.foreach { prop =>
@@ -206,9 +212,6 @@ class CheckpointWriter(
if (latestCheckpointTime == null || latestCheckpointTime < checkpointTime) {
latestCheckpointTime = checkpointTime
}
- if (fs == null) {
- fs = new Path(checkpointDir).getFileSystem(hadoopConf)
- }
var attempts = 0
val startTime = System.currentTimeMillis()
val tempFile = new Path(checkpointDir, "temp")
@@ -228,7 +231,9 @@ class CheckpointWriter(
attempts += 1
try {
logInfo(s"Saving checkpoint for time $checkpointTime to file '$checkpointFile'")
-
+ if (fs == null) {
+ fs = new Path(checkpointDir).getFileSystem(hadoopConf)
+ }
// Write checkpoint to temp file
fs.delete(tempFile, true) // just in case it exists
val fos = fs.create(tempFile)
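Moving the FileSystem lookup inside the write attempt means a handle invalidated by a failed attempt can be re-created on the next retry (the catch block that resets fs sits outside this hunk). A sketch of that retry shape, using a hypothetical helper and a generic handle type rather than the CheckpointWriter itself:

    // Sketch under assumed names: acquire the handle lazily inside each attempt so a
    // handle dropped after a failure is rebuilt before the next attempt.
    object RetryWithHandleSketch {
      def writeWithRetries(maxAttempts: Int)(openHandle: () => AutoCloseable)
                          (write: AutoCloseable => Unit): Boolean = {
        var handle: AutoCloseable = null
        var attempts = 0
        while (attempts < maxAttempts) {
          attempts += 1
          try {
            if (handle == null) {
              handle = openHandle() // (re)acquire only when no usable handle exists
            }
            write(handle)
            return true
          } catch {
            case _: java.io.IOException =>
              // Drop the handle so the next attempt re-acquires it, mirroring how the
              // checkpoint writer re-resolves the FileSystem after a write failure.
              if (handle != null) {
                try handle.close() catch { case _: Exception => () }
              }
              handle = null
          }
        }
        false
      }

      def main(args: Array[String]): Unit = {
        var opens = 0
        val ok = writeWithRetries(maxAttempts = 3) { () =>
          opens += 1
          new AutoCloseable { def close(): Unit = () }
        } { _ => if (opens == 1) throw new java.io.IOException("stale handle") }
        println(s"succeeded = $ok after $opens handle acquisitions") // true, 2
      }
    }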
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
index 5d9a8ac0d9297..dacff69d55dd2 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
@@ -193,12 +193,15 @@ private[streaming] class ReceivedBlockTracker(
getReceivedBlockQueue(receivedBlockInfo.streamId) += receivedBlockInfo
}
- // Insert the recovered block-to-batch allocations and clear the queue of received blocks
- // (when the blocks were originally allocated to the batch, the queue must have been cleared).
+ // Insert the recovered block-to-batch allocations and remove them from the queue of
+ // received blocks.
def insertAllocatedBatch(batchTime: Time, allocatedBlocks: AllocatedBlocks) {
logTrace(s"Recovery: Inserting allocated batch for time $batchTime to " +
s"${allocatedBlocks.streamIdToAllocatedBlocks}")
- streamIdToUnallocatedBlockQueues.values.foreach { _.clear() }
+ allocatedBlocks.streamIdToAllocatedBlocks.foreach {
+ case (streamId, allocatedBlocksInStream) =>
+ getReceivedBlockQueue(streamId).dequeueAll(allocatedBlocksInStream.toSet)
+ }
timeToAllocatedBlocks.put(batchTime, allocatedBlocks)
lastAllocatedBatchTime = batchTime
}
@@ -227,7 +230,7 @@ private[streaming] class ReceivedBlockTracker(
}
/** Write an update to the tracker to the write ahead log */
- private def writeToLog(record: ReceivedBlockTrackerLogEvent): Boolean = {
+ private[streaming] def writeToLog(record: ReceivedBlockTrackerLogEvent): Boolean = {
if (isWriteAheadLogEnabled) {
logTrace(s"Writing record: $record")
try {
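During recovery, insertAllocatedBatch previously cleared every stream's received-block queue; it now dequeues only the blocks that were actually allocated to the recovered batch, so blocks that were received but not yet allocated survive recovery. A self-contained sketch of the queue operation it relies on:

    // dequeueAll with a membership test removes exactly the allocated entries and
    // leaves everything else queued.
    object DequeueAllocatedSketch {
      import scala.collection.mutable

      def main(args: Array[String]): Unit = {
        val received = mutable.Queue("b1", "b2", "b3", "b4")
        val allocated = Seq("b1", "b2")

        // Same shape as getReceivedBlockQueue(streamId).dequeueAll(allocatedBlocksInStream.toSet):
        // the Set doubles as the predicate, so only matching entries are removed.
        received.dequeueAll(allocated.toSet)

        // The old `_.clear()` would also have dropped b3 and b4, which were received
        // but never allocated to the recovered batch.
        assert(received.toSeq == Seq("b3", "b4"))
        println(received)
      }
    }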
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala
index f55af6a5cc358..69e15655ad790 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala
@@ -304,7 +304,10 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") {
}
def render(request: HttpServletRequest): Seq[Node] = streamingListener.synchronized {
- val batchTime = Option(request.getParameter("id")).map(id => Time(id.toLong)).getOrElse {
+ // stripXSS is called first to remove suspicious characters used in XSS attacks
+ val batchTime =
+ Option(SparkUIUtils.stripXSS(request.getParameter("id"))).map(id => Time(id.toLong))
+ .getOrElse {
throw new IllegalArgumentException(s"Missing id parameter")
}
val formattedBatchTime =
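The page now sanitizes the raw id parameter before wrapping it in Option and parsing, so the missing-parameter error path is preserved. A sketch of that ordering with a hypothetical sanitizer; the character blacklist below is an assumption for illustration, not the behavior of SparkUIUtils.stripXSS:

    // Hypothetical sanitizer; the point is the ordering: sanitize first, then Option-wrap and parse.
    object SanitizedParamSketch {
      private def stripSuspicious(raw: String): String =
        if (raw == null) null
        else raw.replaceAll("""[<>"'%;()&+]""", "") // assumed blacklist, for illustration only

      def parseBatchTime(rawId: String): Long =
        Option(stripSuspicious(rawId)).map(_.toLong).getOrElse {
          throw new IllegalArgumentException("Missing id parameter")
        }

      def main(args: Array[String]): Unit = {
        println(parseBatchTime("1500000000000")) // parses normally
        // parseBatchTime(null) would throw IllegalArgumentException, as in the page handler
      }
    }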
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala
index 845f554308c43..1e5f18797e152 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala
@@ -189,7 +189,9 @@ private[streaming] class FileBasedWriteAheadLog(
val f = Future { deleteFile(logInfo) }(executionContext)
if (waitForCompletion) {
import scala.concurrent.duration._
+ // scalastyle:off awaitready
Await.ready(f, 1 second)
+ // scalastyle:on awaitready
}
} catch {
case e: RejectedExecutionException =>
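The scalastyle suppression is scoped to the single Await.ready call. A minimal illustration of what Await.ready does compared with Await.result: it blocks until the future completes but does not rethrow the future's failure.

    object AwaitReadySketch {
      import scala.concurrent.{Await, Future}
      import scala.concurrent.ExecutionContext.Implicits.global
      import scala.concurrent.duration._

      def main(args: Array[String]): Unit = {
        val f = Future { throw new RuntimeException("delete failed") }
        Await.ready(f, 1.second) // returns the completed (failed) future quietly
        println(f.value)         // Some(Failure(java.lang.RuntimeException: delete failed))
        // Await.result(f, 1.second) would rethrow the RuntimeException instead
      }
    }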
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
index 107c3f5dcc08d..4fa236bd39663 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
@@ -33,7 +33,7 @@ import org.apache.spark.{SparkConf, SparkException, SparkFunSuite}
import org.apache.spark.internal.Logging
import org.apache.spark.storage.StreamBlockId
import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult
-import org.apache.spark.streaming.scheduler._
+import org.apache.spark.streaming.scheduler.{AllocatedBlocks, _}
import org.apache.spark.streaming.util._
import org.apache.spark.streaming.util.WriteAheadLogSuite._
import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils}
@@ -94,6 +94,27 @@ class ReceivedBlockTrackerSuite
receivedBlockTracker.getUnallocatedBlocks(streamId) shouldEqual blockInfos
}
+ test("recovery with write ahead logs should remove only allocated blocks from received queue") {
+ val manualClock = new ManualClock
+ val batchTime = manualClock.getTimeMillis()
+
+ val tracker1 = createTracker(clock = manualClock)
+ tracker1.isWriteAheadLogEnabled should be (true)
+
+ val allocatedBlockInfos = generateBlockInfos()
+ val unallocatedBlockInfos = generateBlockInfos()
+ val receivedBlockInfos = allocatedBlockInfos ++ unallocatedBlockInfos
+ receivedBlockInfos.foreach { b => tracker1.writeToLog(BlockAdditionEvent(b)) }
+ val allocatedBlocks = AllocatedBlocks(Map(streamId -> allocatedBlockInfos))
+ tracker1.writeToLog(BatchAllocationEvent(batchTime, allocatedBlocks))
+ tracker1.stop()
+
+ val tracker2 = createTracker(clock = manualClock, recoverFromWriteAheadLog = true)
+ tracker2.getBlocksOfBatch(batchTime) shouldEqual allocatedBlocks.streamIdToAllocatedBlocks
+ tracker2.getUnallocatedBlocks(streamId) shouldEqual unallocatedBlockInfos
+ tracker2.stop()
+ }
+
test("recovery and cleanup with write ahead logs") {
val manualClock = new ManualClock
// Set the time increment level to twice the rotation interval so that every increment creates
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala
index b70383ecde4d8..4f41b9d0a0b3c 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala
@@ -21,7 +21,6 @@ import java.util.concurrent.ConcurrentLinkedQueue
import scala.collection.JavaConverters._
import scala.collection.mutable
-import scala.language.reflectiveCalls
import org.scalatest.BeforeAndAfter
import org.scalatest.Matchers._
@@ -202,21 +201,17 @@ class BlockGeneratorSuite extends SparkFunSuite with BeforeAndAfter {
test("block push errors are reported") {
val listener = new TestBlockGeneratorListener {
- @volatile var errorReported = false
override def onPushBlock(
blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = {
throw new SparkException("test")
}
- override def onError(message: String, throwable: Throwable): Unit = {
- errorReported = true
- }
}
blockGenerator = new BlockGenerator(listener, 0, conf)
blockGenerator.start()
- assert(listener.errorReported === false)
+ assert(listener.onErrorCalled === false)
blockGenerator.addData(1)
eventually(timeout(1 second), interval(10 milliseconds)) {
- assert(listener.errorReported === true)
+ assert(listener.onErrorCalled === true)
}
blockGenerator.stop()
}
@@ -243,12 +238,15 @@ class BlockGeneratorSuite extends SparkFunSuite with BeforeAndAfter {
@volatile var onGenerateBlockCalled = false
@volatile var onAddDataCalled = false
@volatile var onPushBlockCalled = false
+ @volatile var onErrorCalled = false
override def onPushBlock(blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = {
pushedData.addAll(arrayBuffer.asJava)
onPushBlockCalled = true
}
- override def onError(message: String, throwable: Throwable): Unit = {}
+ override def onError(message: String, throwable: Throwable): Unit = {
+ onErrorCalled = true
+ }
override def onGenerateBlock(blockId: StreamBlockId): Unit = {
onGenerateBlockCalled = true
}
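Moving the error flag onto the shared TestBlockGeneratorListener is what lets the scala.language.reflectiveCalls import go away: the flag is no longer an extra member visible only on an anonymous refinement. A small sketch of that distinction, using a hypothetical Listener class rather than the suite's listener:

    object StructuralAccessSketch {
      class Listener { @volatile var onErrorCalled = false }

      def main(args: Array[String]): Unit = {
        // Old shape: the flag exists only on the anonymous refinement, so reading it
        // from outside is a structural (reflective) call and needs
        // `import scala.language.reflectiveCalls`.
        val withRefinement = new Listener { @volatile var errorReported = false }
        println(withRefinement.getClass.getName) // an anonymous subclass

        // New shape: the flag is declared on the named class, so it is plain field
        // access; @volatile keeps the write visible to the asserting thread.
        val plain = new Listener
        plain.onErrorCalled = true
        assert(plain.onErrorCalled)
      }
    }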
diff --git a/tools/pom.xml b/tools/pom.xml
index 938ba2f6ac201..1c6a26734203c 100644
--- a/tools/pom.xml
+++ b/tools/pom.xml
@@ -20,7 +20,7 @@
org.apache.spark
spark-parent_2.11
- 2.2.0-SNAPSHOT
+ 2.2.3-SNAPSHOT
../pom.xml